diff --git a/experimental/robust_segvit/README.md b/experimental/robust_segvit/README.md index 315313b7b..9e0218e5d 100644 --- a/experimental/robust_segvit/README.md +++ b/experimental/robust_segvit/README.md @@ -1,16 +1,25 @@ # Robust segvit -*Robust_segvit* is a codebase to evaluate the robustness of semantic segmentation models. +**Robust_segvit** is a codebase to evaluate the robustness of semantic segmentation models. The code is built on top of [uncertainty_baselines](https://github.com/google/uncertainty-baselines) and [Scenic](https://github.com/google-research/scenic). -Robust_segvit is developed in [JAX](https://github.com/google/jax) and uses [Flax](https://github.com/google/flax), [uncertainty_baselines](https://github.com/google/uncertainty-baselines) and [Scenic](https://github.com/google-research/scenic). +## Installation +Robust_segvit is developed in [JAX](https://github.com/google/jax)/[Flax](https://github.com/google/flax). -## Code structure -See uncertainty_baselines/google/experimental/cityscapes. +To run the code:
+1. Install [uncertainty_baselines](https://github.com/google/uncertainty-baselines).
+2. Install [Scenic](https://github.com/google-research/scenic).
+3. Follow the instructions for a toy run in [./run_deterministic_mac.sh](./run_deterministic_mac.sh). +## Datasets +The experiment configurations for the different datasets are in: -## Cityscapes +- configs/cityscapes: Cityscapes dataset.
+- configs/ade20k_ind: ADE20k_ind dataset.
+- configs/street_hazards: Street Hazards dataset.
-We investigate the performance of different reliability methods on image segmentation tasks.
+## Comments: +- The checkpoint used for finetuning is the same as the one used by the original segmenter model: [vit_large_patch16_384](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) -[x] configs/cityscapes: contains experiment configurations for the cityscapes dataset.
+## Citing work: +If you reference this code, please cite [our paper](https://github.com/google/uncertainty-baselines).
\ No newline at end of file diff --git a/experimental/robust_segvit/___init__.py b/experimental/robust_segvit/___init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/experimental/robust_segvit/checkpoint_utils.py b/experimental/robust_segvit/checkpoint_utils.py new file mode 100644 index 000000000..eeb59d054 --- /dev/null +++ b/experimental/robust_segvit/checkpoint_utils.py @@ -0,0 +1,71 @@ +# load checkpoints +from scenic.train_lib_deprecated import train_utils +from pretrainer_utils import convert_torch_to_jax_checkpoint # local file import from experimental.robust_segvit +from scenic.train_lib_deprecated import pretrain_utils +from pretrainer_utils import convert_vision_transformer_to_scenic # local file import from experimental.robust_segvit + + +def load_checkpoints_eval(config, model, train_state, workdir): + checkpoint_configs = config.get('checkpoint_configs', False) + if checkpoint_configs: + # Load torch weights + if 'torch' in checkpoint_configs.checkpoint_format: + + bb_train_state = convert_torch_to_jax_checkpoint( + checkpoint_path=checkpoint_configs.checkpoint_path, + config=checkpoint_configs) + + train_state = model.init_backbone_from_train_state( + train_state, + bb_train_state, + config, + checkpoint_configs + ) + del bb_train_state + + # Load weights in checkpoint_path or workdir + else: + checkpoint_path = checkpoint_configs.get('checkpoint_path', workdir) + train_state, _ = train_utils.restore_checkpoint( + checkpoint_path, train_state) + return train_state + + +def load_checkpoints_backbone(config, model, train_state, workdir): + del workdir + # TODO(kellybuchanan): check out partial loader in + # https://github.com/google/uncertainty-baselines/commit/083b1dcc52bb1964f8917d15552ece8848d582ae# + restored_model_cfg = config.get('pretrained_backbone_configs') + + # Load pretrained backbone + if restored_model_cfg.checkpoint_format in ('ub', 'big_vision', 'scenic'): + # load params from checkpoint + bb_train_state = 
pretrain_utils.convert_big_vision_to_scenic_checkpoint( + checkpoint_path=restored_model_cfg.checkpoint_path, + convert_to_linen=False) + + train_state = model.init_backbone_from_train_state( + train_state, + bb_train_state, + config, + restored_model_cfg, + model_prefix_path=['backbone']) + # Free unnecessary memory. + del bb_train_state + # Loader from scenic + elif restored_model_cfg.checkpoint_format in ('vision_transformer'): + # load params from checkpoint + bb_train_state = convert_vision_transformer_to_scenic(checkpoint_path=restored_model_cfg.checkpoint_path, convert_to_linen=False) + + train_state = model.init_backbone_from_train_state( + train_state, + bb_train_state, + config, + restored_model_cfg, + model_prefix_path=['backbone']) + + # Free unnecessary memory. + del bb_train_state + else: + raise NotImplementedError('') + return train_state diff --git a/experimental/robust_segvit/configs/ade20k_ind/be.py b/experimental/robust_segvit/configs/ade20k_ind/be.py index 9e149362a..e0410ec30 100644 --- a/experimental/robust_segvit/configs/ade20k_ind/be.py +++ b/experimental/robust_segvit/configs/ade20k_ind/be.py @@ -22,6 +22,8 @@ # pylint: enable=line-too-long import ml_collections +import os +import datetime _CITYSCAPES_FINE_TRAIN_SIZE = 2975 _CITYSCAPES_COARSE_TRAIN_SIZE = 19998 @@ -40,21 +42,24 @@ # Model specs. 
LOAD_PRETRAINED_BACKBONE = True -BACKBONE_ORIGIN = 'big_vision' +BACKBONE_ORIGIN = 'vision_transformer' VIT_SIZE = 'L' STRIDE = 16 RESNET_SIZE = None CLASSIFIER = 'token' target_size = (640, 640) -UPSTREAM_TASK = 'i21k+imagenet2012' +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' # Upstream MODEL_PATHS = { # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 - ('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'): - 'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'): + 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'): + 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', # pylint: disable=g-long-lambda + } @@ -180,9 +185,25 @@ def get_config(runlocal=''): config.eval_label_shift = False config.model.input_shape = target_size + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'msp' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. 
+ if runlocal: config.count_flops = False config.dataset_configs.train_target_size = (128, 128) + config.model.input_shape = config.dataset_configs.train_target_size config.batch_size = 8 config.num_training_epochs = 5 config.warmup_steps = 0 diff --git a/experimental/robust_segvit/configs/ade20k_ind/be_eval.py b/experimental/robust_segvit/configs/ade20k_ind/be_eval.py index f70890536..df0aa2d3a 100644 --- a/experimental/robust_segvit/configs/ade20k_ind/be_eval.py +++ b/experimental/robust_segvit/configs/ade20k_ind/be_eval.py @@ -20,6 +20,8 @@ # pylint: enable=line-too-long import ml_collections +import datetime +import os _CITYSCAPES_FINE_TRAIN_SIZE = 2975 _CITYSCAPES_COARSE_TRAIN_SIZE = 19998 @@ -42,14 +44,14 @@ STRIDE = 16 RESNET_SIZE = None CLASSIFIER = 'token' -EXPERIMENTID = '43838358-2' +EXPERIMENTID = '45349725-1' target_size = (640, 640) # Upstream CHECKPOINT_PATHS = { - ('ub', 'L', 16, None, 'token', '43838358-2'): - 'gs://ub-ekb/checkpoints_to_upload/ade20k/43838358-2', + ('ub', 'L', 16, None, 'token', '45349725-1'): + 'gs://ub-checkpoints/45349725-ade20k_ind_segmenter_be/1', } @@ -162,17 +164,33 @@ def get_config(runlocal=''): config.eval_mode = True config.eval_configs = ml_collections.ConfigDict() config.eval_configs.mode = 'standard' - config.eval_covariate_shift = True - config.eval_label_shift = True config.model.input_shape = target_size + config.eval_configs.store_logits = False # Eval parameters for robustness config.eval_label_shift = True config.eval_covariate_shift = True config.eval_robustness_configs = ml_collections.ConfigDict() config.eval_robustness_configs.auc_online = True - config.eval_robustness_configs.method_name = 'msp' - config.eval_robustness_configs.num_top_k = 5 + config.eval_robustness_configs.method_name = 'nmlogit' + config.eval_robustness_configs.num_top_k = 1 + + # Load checkpoint + config.checkpoint_configs = ml_collections.ConfigDict() + config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN + 
config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH + config.checkpoint_configs.classifier = 'token' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. if runlocal: config.count_flops = False @@ -183,12 +201,6 @@ def get_config(runlocal=''): config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]' config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]' config.num_train_examples = TRAIN_SAMPLES - else: - # Load checkpoint - config.checkpoint_configs = ml_collections.ConfigDict() - config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN - config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH - config.checkpoint_configs.classifier = 'token' return config diff --git a/experimental/robust_segvit/configs/ade20k_ind/deterministic.py b/experimental/robust_segvit/configs/ade20k_ind/deterministic.py index 9472fd60b..33883665a 100644 --- a/experimental/robust_segvit/configs/ade20k_ind/deterministic.py +++ b/experimental/robust_segvit/configs/ade20k_ind/deterministic.py @@ -22,6 +22,8 @@ # pylint: enable=line-too-long import ml_collections +import os +import datetime _CITYSCAPES_FINE_TRAIN_SIZE = 2975 _CITYSCAPES_COARSE_TRAIN_SIZE = 19998 @@ -40,21 +42,23 @@ # Model specs. 
LOAD_PRETRAINED_BACKBONE = True -BACKBONE_ORIGIN = 'big_vision' +BACKBONE_ORIGIN = 'vision_transformer' VIT_SIZE = 'L' STRIDE = 16 RESNET_SIZE = None CLASSIFIER = 'token' target_size = (640, 640) -UPSTREAM_TASK = 'i21k+imagenet2012' +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' # Upstream MODEL_PATHS = { # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 - ('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'): - 'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'): + 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'): + 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', } @@ -174,9 +178,25 @@ def get_config(runlocal=''): config.eval_label_shift = False config.model.input_shape = target_size + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'msp' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. 
+ if runlocal: config.count_flops = False config.dataset_configs.train_target_size = (128, 128) + config.model.input_shape = config.dataset_configs.train_target_size config.batch_size = 8 config.num_training_epochs = 5 config.warmup_steps = 0 diff --git a/experimental/robust_segvit/configs/ade20k_ind/deterministic_eval.py b/experimental/robust_segvit/configs/ade20k_ind/deterministic_eval.py new file mode 100644 index 000000000..e5f8ff97e --- /dev/null +++ b/experimental/robust_segvit/configs/ade20k_ind/deterministic_eval.py @@ -0,0 +1,243 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Eval segmenter model on ade20k_ind. + + +""" +# pylint: enable=line-too-long + +import ml_collections +import os +import datetime + +_CITYSCAPES_FINE_TRAIN_SIZE = 2975 +_CITYSCAPES_COARSE_TRAIN_SIZE = 19998 + +_ADE20K_TRAIN_SIZE = 20210 +_PASCAL_VOC_TRAIN_SIZE = 10582 +_PASCAL_CONTEXT_TRAIN_SIZE = 4998 + +TRAIN_SIZES = { + 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE, + 'ade20k': _ADE20K_TRAIN_SIZE, + 'ade20k_ind': _ADE20K_TRAIN_SIZE, + 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE, + 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE +} + +# Model specs. 
+VIT_SIZE = 'L' +STRIDE = 16 +RESNET_SIZE = None +CLASSIFIER = 'token' +target_size = (640, 640) + +CHECKPOINT_ORIGIN = 'ub' +EXPERIMENTID = '45373386-1' + +# Upstream +CHECKPOINT_PATHS = { + ('ub', 'L', 16, None, 'token', '45373386-1'): + 'gs://ub-checkpoints/45373386-ade20k_ind_deterministic/1', +} + + +CHECKPOINT_PATH = CHECKPOINT_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE, + RESNET_SIZE, CLASSIFIER, EXPERIMENTID)] + +if VIT_SIZE == 'B': + mlp_dim = 3072 + num_heads = 12 + num_layers = 12 + hidden_size = 768 +elif VIT_SIZE == 'L': + mlp_dim = 4096 + num_heads = 16 + num_layers = 24 + hidden_size = 1024 + +TRAIN_SAMPLES = 32 + + +def get_config(runlocal=''): + """Returns the configuration for ADE20k_ind segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'ade20k_ind_deterministic_eval' + + # Dataset. + config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = target_size + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + + config.dataset_configs.train_split = 'train' + config.dataset_configs.name = 'ade20k_ind' + config.dataset_configs.dataset_name = '' # ood name flag to write in eval. + + # Model. 
+ config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.0 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = CLASSIFIER + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'linear' + + # Training. + config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(100) + config.batch_size = 32 + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. + config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name) + config.steps_per_epoch = config.get_ref( + 'num_train_examples') // config.get_ref('batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch') + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 1e-4 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # Logging. 
+ config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. + config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. + config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. + config.eval_mode = True + config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'standard' + config.model.input_shape = target_size + config.eval_configs.store_logits = False + + # Eval parameters for robustness + config.eval_label_shift = True + config.eval_covariate_shift = True + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'nmlogit' + config.eval_robustness_configs.num_top_k = 1 + + # Load checkpoint + config.checkpoint_configs = ml_collections.ConfigDict() + config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN + config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH + config.checkpoint_configs.classifier = 'token' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. 
+ + if runlocal: + config.count_flops = False + config.dataset_configs.train_target_size = (128, 128) + config.model.input_shape = config.dataset_configs.train_target_size + config.batch_size = 8 + config.num_training_epochs = 5 + config.warmup_steps = 0 + config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]' + config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]' + config.num_train_examples = TRAIN_SAMPLES + + return config + + +def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task): + """Defines checkpoints for sweep.""" + overwrites = [] + if resnet_size is not None: + raise NotImplementedError('') + else: + overwrites.append( + hyper.sweep('config.model.patches', [{ + 'size': (stride, stride) + }])) + + if vit_size == 'B': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768])) + elif vit_size == 'L': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024])) + else: + raise NotImplementedError('') + + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_format', + [backbone_origin])) + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [ + CHECKPOINT_PATHS[(backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task)] + ])) + + return hyper.product(overwrites) + + +def get_sweep(hyper): + return hyper.product([]) + diff --git a/experimental/robust_segvit/configs/ade20k_ind/deterministic_seeds.py 
b/experimental/robust_segvit/configs/ade20k_ind/deterministic_seeds.py new file mode 100644 index 000000000..d7e7468ef --- /dev/null +++ b/experimental/robust_segvit/configs/ade20k_ind/deterministic_seeds.py @@ -0,0 +1,253 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Train segmenter model on ade20k_ind. + +Compare performance across seeds. + +""" +# pylint: enable=line-too-long + +import ml_collections +import os +import datetime + +_CITYSCAPES_FINE_TRAIN_SIZE = 2975 +_CITYSCAPES_COARSE_TRAIN_SIZE = 19998 + +_ADE20K_TRAIN_SIZE = 20210 +_PASCAL_VOC_TRAIN_SIZE = 10582 +_PASCAL_CONTEXT_TRAIN_SIZE = 4998 + +TRAIN_SIZES = { + 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE, + 'ade20k': _ADE20K_TRAIN_SIZE, + 'ade20k_ind': _ADE20K_TRAIN_SIZE, + 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE, + 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE +} + +# Model specs. 
+LOAD_PRETRAINED_BACKBONE = True +BACKBONE_ORIGIN = 'vision_transformer' +VIT_SIZE = 'L' +STRIDE = 16 +RESNET_SIZE = None +CLASSIFIER = 'token' +target_size = (640, 640) +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' + + +# Upstream +MODEL_PATHS = { + + # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 + ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'): + 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'): + 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', +} + + +MODEL_PATH = MODEL_PATHS[(BACKBONE_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE, + CLASSIFIER, UPSTREAM_TASK)] + +if VIT_SIZE == 'B': + mlp_dim = 3072 + num_heads = 12 + num_layers = 12 + hidden_size = 768 +elif VIT_SIZE == 'L': + mlp_dim = 4096 + num_heads = 16 + num_layers = 24 + hidden_size = 1024 + +TRAIN_SAMPLES = 32 + + +def get_config(runlocal=''): + """Returns the configuration for ADE20k_ind segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'ade20k_ind_deterministic_seeds' + + # Dataset. + config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = target_size + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + + config.dataset_configs.train_split = 'train' + config.dataset_configs.name = 'ade20k_ind' + config.dataset_configs.dataset_name = '' # ood name flag to write in eval. + + # Model. 
+ config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.1 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = CLASSIFIER + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'linear' + + # Training. + config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(100) + config.batch_size = 32 + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. + config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name) + config.steps_per_epoch = config.get_ref( + 'num_train_examples') // config.get_ref('batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. 
+ config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch') + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 3e-5 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # load pretrained backbone + config.load_pretrained_backbone = LOAD_PRETRAINED_BACKBONE + config.pretrained_backbone_configs = ml_collections.ConfigDict() + config.pretrained_backbone_configs.checkpoint_format = BACKBONE_ORIGIN + config.pretrained_backbone_configs.checkpoint_path = MODEL_PATH + config.pretrained_backbone_configs.token_init = True + config.pretrained_backbone_configs.classifier = 'token' + config.pretrained_backbone_configs.backbone_type = 'vit' + + # Logging. + config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. + config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. + config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. + config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'standard' + config.eval_mode = False + config.eval_covariate_shift = False + config.eval_label_shift = False + config.model.input_shape = target_size + + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'mlogit' + + # wandb.ai configurations. 
+ config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. + + if runlocal: + config.count_flops = False + config.dataset_configs.train_target_size = (128, 128) + config.model.input_shape = config.dataset_configs.train_target_size + config.batch_size = 8 + config.num_training_epochs = 5 + config.warmup_steps = 0 + config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]' + config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]' + config.num_train_examples = TRAIN_SAMPLES + + return config + + +def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task): + """Defines checkpoints for sweep.""" + overwrites = [] + if resnet_size is not None: + raise NotImplementedError('') + else: + overwrites.append( + hyper.sweep('config.model.patches', [{ + 'size': (stride, stride) + }])) + + if vit_size == 'B': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768])) + elif vit_size == 'L': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024])) + else: + raise NotImplementedError('') + + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_format', + [backbone_origin])) + 
overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [ + MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task)] + ])) + + return hyper.product(overwrites) + + +def get_sweep(hyper): + """Defines the hyper-parameters sweeps for doing grid search.""" + + seeds = hyper.sweep('config.rng_seed', range(0, 5)) + + + return hyper.product([seeds]) diff --git a/experimental/robust_segvit/configs/ade20k_ind/gp.py b/experimental/robust_segvit/configs/ade20k_ind/gp.py index fe3c3186d..bd022cbc6 100644 --- a/experimental/robust_segvit/configs/ade20k_ind/gp.py +++ b/experimental/robust_segvit/configs/ade20k_ind/gp.py @@ -22,6 +22,8 @@ # pylint: enable=line-too-long import ml_collections +import os +import datetime _CITYSCAPES_FINE_TRAIN_SIZE = 2975 _CITYSCAPES_COARSE_TRAIN_SIZE = 19998 @@ -40,21 +42,23 @@ # Model specs. LOAD_PRETRAINED_BACKBONE = True -BACKBONE_ORIGIN = 'big_vision' +BACKBONE_ORIGIN = 'vision_transformer' VIT_SIZE = 'L' STRIDE = 16 RESNET_SIZE = None CLASSIFIER = 'token' target_size = (640, 640) -UPSTREAM_TASK = 'i21k+imagenet2012' +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' # Upstream MODEL_PATHS = { # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 - ('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'): - 'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'): + 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'): + 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', } @@ -187,9 +191,25 @@ def get_config(runlocal=''): config.eval_label_shift = False config.model.input_shape = target_size + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + 
config.eval_robustness_configs.method_name = 'msp' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. + if runlocal: config.count_flops = False config.dataset_configs.train_target_size = (128, 128) + config.model.input_shape = config.dataset_configs.train_target_size config.batch_size = 8 config.num_training_epochs = 5 config.warmup_steps = 0 diff --git a/experimental/robust_segvit/configs/ade20k_ind/gp_eval.py b/experimental/robust_segvit/configs/ade20k_ind/gp_eval.py new file mode 100644 index 000000000..f505258d7 --- /dev/null +++ b/experimental/robust_segvit/configs/ade20k_ind/gp_eval.py @@ -0,0 +1,255 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Eval segmenter model on ade20k_ind. 
+ +""" +# pylint: enable=line-too-long + +import ml_collections +import os +import datetime + +_CITYSCAPES_FINE_TRAIN_SIZE = 2975 +_CITYSCAPES_COARSE_TRAIN_SIZE = 19998 + +_ADE20K_TRAIN_SIZE = 20210 +_PASCAL_VOC_TRAIN_SIZE = 10582 +_PASCAL_CONTEXT_TRAIN_SIZE = 4998 + +TRAIN_SIZES = { + 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE, + 'ade20k': _ADE20K_TRAIN_SIZE, + 'ade20k_ind': _ADE20K_TRAIN_SIZE, + 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE, + 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE +} + +# Model specs. +VIT_SIZE = 'L' +STRIDE = 16 +RESNET_SIZE = None +CLASSIFIER = 'token' +target_size = (640, 640) + +CHECKPOINT_ORIGIN = 'ub' +EXPERIMENTID='45350699-1' +# Upstream +CHECKPOINT_PATHS = { + ('ub', 'L', 16, None, 'token', '45350699-1'): + 'gs://ub-checkpoints/45350699-ade20k_ind_segmenter_gp/1', +} + + +CHECKPOINT_PATH = CHECKPOINT_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE, + RESNET_SIZE, CLASSIFIER, EXPERIMENTID)] + + +if VIT_SIZE == 'B': + mlp_dim = 3072 + num_heads = 12 + num_layers = 12 + hidden_size = 768 +elif VIT_SIZE == 'L': + mlp_dim = 4096 + num_heads = 16 + num_layers = 24 + hidden_size = 1024 + +TRAIN_SAMPLES = 32 + + +def get_config(runlocal=''): + """Returns the configuration for ADE20k_ind segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'ade20k_ind_segmenter_gp_eval' + + # Dataset. + config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = target_size + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + + config.dataset_configs.train_split = 'train' + config.dataset_configs.name = 'ade20k_ind' + config.dataset_configs.dataset_name = '' # ood name flag to write in eval. + + # Model. 
+ config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.0 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = CLASSIFIER + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'gp' + + # GP layer params + config.model.decoder.gp_layer = ml_collections.ConfigDict() + config.model.decoder.gp_layer.covmat_kwargs = ml_collections.ConfigDict() + config.model.decoder.gp_layer.covmat_kwargs.ridge_penalty = 1. + # Momentum for the moving-average covariance update (0.99 keeps it enabled). + # NOTE: set to -1 to disable momentum and use the exact covariance update. + config.model.decoder.gp_layer.covmat_kwargs.momentum = 0.99 + config.model.decoder.mean_field_factor = 1. + # Additional params + config.model.decoder.gp_layer.normalize_input = True + config.model.decoder.gp_layer.hidden_kwargs = ml_collections.ConfigDict() + config.model.decoder.gp_layer.hidden_kwargs.feature_scale = 1. + + # Training. + config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(100) + config.batch_size = 32 + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. 
+ config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name) + config.steps_per_epoch = config.get_ref( + 'num_train_examples') // config.get_ref('batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch') + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 3e-5 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # Logging. + config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. + config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. + config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. 
+ config.eval_mode = True + config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'standard' + config.model.input_shape = target_size + config.eval_configs.store_logits = False + + # Eval parameters for robustness + config.eval_label_shift = True + config.eval_covariate_shift = True + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'nmlogit' + config.eval_robustness_configs.num_top_k = 1 + + # Load checkpoint + config.checkpoint_configs = ml_collections.ConfigDict() + config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN + config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH + config.checkpoint_configs.classifier = 'token' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. 
+ + if runlocal: + config.count_flops = False + config.dataset_configs.train_target_size = (128, 128) + config.model.input_shape = config.dataset_configs.train_target_size + config.batch_size = 8 + config.num_training_epochs = 5 + config.warmup_steps = 0 + config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]' + config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]' + config.num_train_examples = TRAIN_SAMPLES + + return config + + +def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task): + """Defines checkpoints for sweep.""" + overwrites = [] + if resnet_size is not None: + raise NotImplementedError('') + else: + overwrites.append( + hyper.sweep('config.model.patches', [{ + 'size': (stride, stride) + }])) + + if vit_size == 'B': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768])) + elif vit_size == 'L': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024])) + else: + raise NotImplementedError('') + + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_format', + [backbone_origin])) + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [ + CHECKPOINT_PATHS[(backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task)] + ])) + + return hyper.product(overwrites) + + +def get_sweep(hyper): + return hyper.product([]) + diff --git a/experimental/robust_segvit/configs/ade20k_ind/gp_seeds.py 
b/experimental/robust_segvit/configs/ade20k_ind/gp_seeds.py new file mode 100644 index 000000000..5c69a5728 --- /dev/null +++ b/experimental/robust_segvit/configs/ade20k_ind/gp_seeds.py @@ -0,0 +1,267 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Train segmenter model on ade20k_ind. + +Compare performance across seeds. + +""" +# pylint: enable=line-too-long + +import ml_collections +import os +import datetime + +_CITYSCAPES_FINE_TRAIN_SIZE = 2975 +_CITYSCAPES_COARSE_TRAIN_SIZE = 19998 + +_ADE20K_TRAIN_SIZE = 20210 +_PASCAL_VOC_TRAIN_SIZE = 10582 +_PASCAL_CONTEXT_TRAIN_SIZE = 4998 + +TRAIN_SIZES = { + 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE, + 'ade20k': _ADE20K_TRAIN_SIZE, + 'ade20k_ind': _ADE20K_TRAIN_SIZE, + 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE, + 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE +} + +# Model specs. 
+LOAD_PRETRAINED_BACKBONE = True +BACKBONE_ORIGIN = 'vision_transformer' +VIT_SIZE = 'L' +STRIDE = 16 +RESNET_SIZE = None +CLASSIFIER = 'token' +target_size = (640, 640) +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' + + +# Upstream +MODEL_PATHS = { + + # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 + ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'): + 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'): + 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', +} + + +MODEL_PATH = MODEL_PATHS[(BACKBONE_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE, + CLASSIFIER, UPSTREAM_TASK)] + +if VIT_SIZE == 'B': + mlp_dim = 3072 + num_heads = 12 + num_layers = 12 + hidden_size = 768 +elif VIT_SIZE == 'L': + mlp_dim = 4096 + num_heads = 16 + num_layers = 24 + hidden_size = 1024 + +TRAIN_SAMPLES = 32 + + +def get_config(runlocal=''): + """Returns the configuration for ADE20k_ind segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'ade20k_ind_segmenter_gp_seeds' + + # Dataset. + config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = target_size + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + + config.dataset_configs.train_split = 'train' + config.dataset_configs.name = 'ade20k_ind' + config.dataset_configs.dataset_name = '' # ood name flag to write in eval. + + # Model. 
+ config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.1 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = CLASSIFIER + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'gp' + + # GP layer params + config.model.decoder.gp_layer = ml_collections.ConfigDict() + config.model.decoder.gp_layer.covmat_kwargs = ml_collections.ConfigDict() + config.model.decoder.gp_layer.covmat_kwargs.ridge_penalty = 1. + # Momentum for the moving-average covariance update (0.99 keeps it enabled). + # NOTE: set to -1 to disable momentum and use the exact covariance update. + config.model.decoder.gp_layer.covmat_kwargs.momentum = 0.99 + config.model.decoder.mean_field_factor = 9. + # Additional params + config.model.decoder.gp_layer.normalize_input = True + config.model.decoder.gp_layer.hidden_kwargs = ml_collections.ConfigDict() + config.model.decoder.gp_layer.hidden_kwargs.feature_scale = 1. + + # Training. + config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(100) + config.batch_size = 32 + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. 
+ config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name) + config.steps_per_epoch = config.get_ref( + 'num_train_examples') // config.get_ref('batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch') + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 3e-5 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # load pretrained backbone + config.load_pretrained_backbone = LOAD_PRETRAINED_BACKBONE + config.pretrained_backbone_configs = ml_collections.ConfigDict() + config.pretrained_backbone_configs.checkpoint_format = BACKBONE_ORIGIN + config.pretrained_backbone_configs.checkpoint_path = MODEL_PATH + config.pretrained_backbone_configs.token_init = True + config.pretrained_backbone_configs.classifier = 'token' + config.pretrained_backbone_configs.backbone_type = 'vit' + + # Logging. + config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. + config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. + config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. 
+ config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'standard' + config.eval_mode = False + config.eval_covariate_shift = False + config.eval_label_shift = False + config.model.input_shape = target_size + + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'mlogit' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. + + if runlocal: + config.count_flops = False + config.dataset_configs.train_target_size = (128, 128) + config.model.input_shape = config.dataset_configs.train_target_size + config.batch_size = 8 + config.num_training_epochs = 5 + config.warmup_steps = 0 + config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]' + config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]' + config.num_train_examples = TRAIN_SAMPLES + + return config + + +def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task): + """Defines checkpoints for sweep.""" + overwrites = [] + if resnet_size is not None: + raise NotImplementedError('') + else: + overwrites.append( + hyper.sweep('config.model.patches', [{ + 'size': (stride, stride) + }])) + + if vit_size == 'B': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768])) + elif vit_size == 'L': + 
overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024])) + else: + raise NotImplementedError('') + + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_format', + [backbone_origin])) + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [ + MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task)] + ])) + + return hyper.product(overwrites) + + +def get_sweep(hyper): + """Defines the hyper-parameters sweeps for doing grid search.""" + + seeds = hyper.sweep('config.rng_seed', range(0, 5)) + + return hyper.product([seeds]) + + diff --git a/experimental/robust_segvit/configs/ade20k_ind/het.py b/experimental/robust_segvit/configs/ade20k_ind/het.py index 52ed0f2dd..b85e701ff 100644 --- a/experimental/robust_segvit/configs/ade20k_ind/het.py +++ b/experimental/robust_segvit/configs/ade20k_ind/het.py @@ -22,6 +22,8 @@ # pylint: enable=line-too-long import ml_collections +import os +import datetime _CITYSCAPES_FINE_TRAIN_SIZE = 2975 _CITYSCAPES_COARSE_TRAIN_SIZE = 19998 @@ -40,21 +42,23 @@ # Model specs. 
LOAD_PRETRAINED_BACKBONE = True -BACKBONE_ORIGIN = 'big_vision' +BACKBONE_ORIGIN = 'vision_transformer' VIT_SIZE = 'L' STRIDE = 16 RESNET_SIZE = None CLASSIFIER = 'token' target_size = (640, 640) -UPSTREAM_TASK = 'i21k+imagenet2012' +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' # Upstream MODEL_PATHS = { # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 - ('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'): - 'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'): + 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'): + 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', } @@ -187,9 +191,25 @@ def get_config(runlocal=''): config.eval_label_shift = False config.model.input_shape = target_size + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'msp' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. 
+ if runlocal: config.count_flops = False config.dataset_configs.train_target_size = (128, 128) + config.model.input_shape = config.dataset_configs.train_target_size config.batch_size = 8 config.num_training_epochs = 5 config.warmup_steps = 0 diff --git a/experimental/robust_segvit/configs/ade20k_ind/het_eval.py b/experimental/robust_segvit/configs/ade20k_ind/het_eval.py index 66f06f8e4..a8385762d 100644 --- a/experimental/robust_segvit/configs/ade20k_ind/het_eval.py +++ b/experimental/robust_segvit/configs/ade20k_ind/het_eval.py @@ -22,6 +22,8 @@ # pylint: enable=line-too-long import ml_collections +import datetime +import os _CITYSCAPES_FINE_TRAIN_SIZE = 2975 _CITYSCAPES_COARSE_TRAIN_SIZE = 19998 @@ -44,14 +46,14 @@ STRIDE = 16 RESNET_SIZE = None CLASSIFIER = 'token' -EXPERIMENTID = '43838062-14' +EXPERIMENTID = '45350817-1' target_size = (640, 640) # Upstream CHECKPOINT_PATHS = { - ('ub', 'L', 16, None, 'token', '43838062-14'): - 'gs://ub-ekb/checkpoints_to_upload/ade20k/43838062-14', + ('ub', 'L', 16, None, 'token', '45350817-1'): + 'gs://ub-checkpoints/45350817-ade20k_ind_segmenter_het_hyper/1', } @@ -116,12 +118,12 @@ def get_config(runlocal=''): # Het layer params # temp: wide sweep [0.15, 0.3, 0.5, 0.75, 1.0, 1.5, 2.0] - config.model.decoder.temperature = 2.0 + config.model.decoder.temperature = 1.0 # efficient low rank approx ~ FxK where K is the classes. False for K<20. config.model.decoder.param_efficient = False # F as a low rank approx of KxK matrix has num_factors: # imagenet~15, jft~50, cifar~6, cityscapes~sweep(5-10). - config.model.decoder.num_factors = 10 + config.model.decoder.num_factors = 5 # mc_samples: use as much as can be afforded, ideally > 10. 
config.model.decoder.mc_samples = 1000 config.model.decoder.return_locs = False @@ -172,17 +174,34 @@ def get_config(runlocal=''): config.eval_mode = True config.eval_configs = ml_collections.ConfigDict() config.eval_configs.mode = 'standard' - config.eval_covariate_shift = True - config.eval_label_shift = True config.model.input_shape = target_size + config.eval_configs.store_logits = False # Eval parameters for robustness config.eval_label_shift = True config.eval_covariate_shift = True config.eval_robustness_configs = ml_collections.ConfigDict() config.eval_robustness_configs.auc_online = True - config.eval_robustness_configs.method_name = 'msp' - config.eval_robustness_configs.num_top_k = 5 + config.eval_robustness_configs.method_name = 'nmlogit' + config.eval_robustness_configs.num_top_k = 1 + + # Load checkpoint + config.checkpoint_configs = ml_collections.ConfigDict() + config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN + config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH + config.checkpoint_configs.classifier = 'token' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. 
+ if runlocal: config.count_flops = False @@ -193,12 +212,6 @@ def get_config(runlocal=''): config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]' config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]' config.num_train_examples = TRAIN_SAMPLES - else: - # Load checkpoint - config.checkpoint_configs = ml_collections.ConfigDict() - config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN - config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH - config.checkpoint_configs.classifier = 'token' return config diff --git a/experimental/robust_segvit/configs/ade20k_ind/toy_model.py b/experimental/robust_segvit/configs/ade20k_ind/toy_model.py index cfc274317..55dbdfdd7 100644 --- a/experimental/robust_segvit/configs/ade20k_ind/toy_model.py +++ b/experimental/robust_segvit/configs/ade20k_ind/toy_model.py @@ -20,6 +20,8 @@ # pylint: enable=line-too-long import ml_collections +import os +import datetime _CITYSCAPES_FINE_TRAIN_SIZE = 2975 _CITYSCAPES_COARSE_TRAIN_SIZE = 19998 @@ -135,8 +137,23 @@ def get_config(runlocal=''): config.eval_mode = False config.eval_configs = ml_collections.ConfigDict() config.eval_configs.mode = 'standard' - config.eval_covariate_shift = False - config.eval_label_shift = False + config.eval_covariate_shift = True + config.eval_label_shift = True + + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'mlogit' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. 
if runlocal: config.count_flops = False diff --git a/experimental/robust_segvit/configs/cityscapes/be.py b/experimental/robust_segvit/configs/cityscapes/be.py index 6f123e6db..2bf6586f6 100644 --- a/experimental/robust_segvit/configs/cityscapes/be.py +++ b/experimental/robust_segvit/configs/cityscapes/be.py @@ -22,27 +22,31 @@ # pylint: enable=line-too-long import ml_collections +import os +import datetime _CITYSCAPES_TRAIN_SIZE = 2975 _CITYSCAPES_TRAIN_SIZE_SPLIT = 146 # Model specs. LOAD_PRETRAINED_BACKBONE = True -BACKBONE_ORIGIN = 'big_vision' +BACKBONE_ORIGIN = 'vision_transformer' VIT_SIZE = 'L' STRIDE = 16 RESNET_SIZE = None CLASSIFIER = 'token' target_size = (768, 768) -UPSTREAM_TASK = 'i21k+imagenet2012' +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' # Upstream MODEL_PATHS = { # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 - ('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'): - 'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'): + 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'): + 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', } @@ -157,6 +161,23 @@ def get_config(runlocal=''): config.eval_configs.mode = 'standard' config.eval_covariate_shift = True config.eval_label_shift = True + config.model.input_shape = target_size + + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'msp' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. 
+ config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. + if runlocal: config.count_flops = False diff --git a/experimental/robust_segvit/configs/cityscapes/be_eval.py b/experimental/robust_segvit/configs/cityscapes/be_eval.py index 38f30d6b0..72b6b63e5 100644 --- a/experimental/robust_segvit/configs/cityscapes/be_eval.py +++ b/experimental/robust_segvit/configs/cityscapes/be_eval.py @@ -14,14 +14,14 @@ # limitations under the License. # pylint: disable=line-too-long -r"""Train segmenter model on cityscapes dataset. - -Compare performance from deterministic upstream checkpoints. +r"""Evaluate segmenter_be model on cityscapes dataset. """ # pylint: enable=line-too-long import ml_collections +import os +import datetime _CITYSCAPES_TRAIN_SIZE = 2975 _CITYSCAPES_TRAIN_SIZE_SPLIT = 146 @@ -33,12 +33,12 @@ RESNET_SIZE = None CLASSIFIER = 'token' target_size = (768, 768) -EXPERIMENTID = '43838585-16' +EXPERIMENTID = '45338505-1' # Upstream CHECKPOINT_PATHS = { - ('ub', 'L', 16, None, 'token', '43838585-16'): - 'gs://ub-ekb/checkpoints_to_upload/cityscapes/43838585-16', + ('ub', 'L', 16, None, 'token', '45338505-1'): + 'gs://ub-checkpoints/45338505-cityscapes_segmenter_be/1', } @@ -66,11 +66,15 @@ def get_config(runlocal=''): config.experiment_name = 'cityscapes_segmenter_be_eval' # Dataset. 
- config.dataset_name = 'cityscapes' + config.dataset_name = 'robust_segvit_segmentation' config.dataset_configs = ml_collections.ConfigDict() config.dataset_configs.target_size = (1024, 2048) config.dataset_configs.train_split = 'train' - config.dataset_configs.dataset_name = '' # name of ood dataset to evaluate + config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 # Model. config.model_name = 'segvit' @@ -143,8 +147,7 @@ def get_config(runlocal=''): config.eval_configs = ml_collections.ConfigDict() config.eval_configs.mode = 'segmm' config.eval_configs.window_stride = 512 - config.eval_covariate_shift = True - config.eval_label_shift = True + config.eval_configs.store_logits = False config.model.input_shape = target_size # Eval parameters for robustness @@ -152,8 +155,25 @@ def get_config(runlocal=''): config.eval_covariate_shift = True config.eval_robustness_configs = ml_collections.ConfigDict() config.eval_robustness_configs.auc_online = True - config.eval_robustness_configs.method_name = 'msp' - config.eval_robustness_configs.num_top_k = 5 + config.eval_robustness_configs.method_name = 'nmlogit' + config.eval_robustness_configs.num_top_k = 1 + + # Load checkpoint + config.checkpoint_configs = ml_collections.ConfigDict() + config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN + config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH + config.checkpoint_configs.classifier = 'token' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. 
+ config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. if runlocal: config.count_flops = False @@ -164,12 +184,6 @@ def get_config(runlocal=''): config.dataset_configs.train_split = 'train[:5%]' config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE_SPLIT // config.get_ref( 'batch_size') - else: - # Load checkpoint - config.checkpoint_configs = ml_collections.ConfigDict() - config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN - config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH - config.checkpoint_configs.classifier = 'token' return config diff --git a/experimental/robust_segvit/configs/cityscapes/deterministic.py b/experimental/robust_segvit/configs/cityscapes/deterministic.py index e3d249b94..56b9c290c 100644 --- a/experimental/robust_segvit/configs/cityscapes/deterministic.py +++ b/experimental/robust_segvit/configs/cityscapes/deterministic.py @@ -22,27 +22,31 @@ # pylint: enable=line-too-long import ml_collections +import os +import datetime _CITYSCAPES_TRAIN_SIZE = 2975 _CITYSCAPES_TRAIN_SIZE_SPLIT = 146 # Model specs. 
LOAD_PRETRAINED_BACKBONE = True -BACKBONE_ORIGIN = 'big_vision' +BACKBONE_ORIGIN = 'vision_transformer' VIT_SIZE = 'L' STRIDE = 16 RESNET_SIZE = None CLASSIFIER = 'token' target_size = (768, 768) -UPSTREAM_TASK = 'i21k+imagenet2012' +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' # Upstream MODEL_PATHS = { # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 - ('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'): - 'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'): + 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'): + 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', } @@ -70,11 +74,15 @@ def get_config(runlocal=''): config.experiment_name = 'cityscapes_segmenter_pretrained' # Dataset. - config.dataset_name = 'cityscapes' + config.dataset_name = 'robust_segvit_segmentation' config.dataset_configs = ml_collections.ConfigDict() config.dataset_configs.target_size = target_size config.dataset_configs.train_split = 'train' - config.dataset_configs.dataset_name = '' # name of ood dataset to evaluate + config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 # Model. config.model_name = 'segvit' @@ -151,6 +159,22 @@ def get_config(runlocal=''): config.eval_configs.mode = 'standard' config.eval_covariate_shift = True config.eval_label_shift = True + config.model.input_shape = target_size + + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'msp' + + # wandb.ai configurations. 
+ config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. if runlocal: config.count_flops = False @@ -211,8 +235,8 @@ def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size, def get_sweep(hyper): """Defines the hyper-parameters sweeps for doing grid search.""" checkpoints = hyper.chainit([ - checkpoint(hyper, 'big_vision', 'L', 16, None, 'token', - 'i21k+imagenet2012'), + checkpoint(hyper, 'vision_transformer', 'L', 16, None, 'token', + 'augreg+i21k+imagenet2012'), ]) epochs = hyper.sweep('config.num_training_epochs', [50, 100, 300]) diff --git a/experimental/robust_segvit/configs/cityscapes/deterministic_eval.py b/experimental/robust_segvit/configs/cityscapes/deterministic_eval.py new file mode 100644 index 000000000..363761709 --- /dev/null +++ b/experimental/robust_segvit/configs/cityscapes/deterministic_eval.py @@ -0,0 +1,232 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Eval segmenter model trained on cityscapes dataset. 
+ +""" +# pylint: enable=line-too-long + +import ml_collections +import os +import datetime + +_CITYSCAPES_TRAIN_SIZE = 2975 +_CITYSCAPES_TRAIN_SIZE_SPLIT = 146 + +# Model specs. +VIT_SIZE = 'L' +STRIDE = 16 +RESNET_SIZE = None +CLASSIFIER = 'token' +target_size = (768, 768) + + +CHECKPOINT_ORIGIN = 'ub' +EXPERIMENTID = '45337813-1' + +# Upstream +CHECKPOINT_PATHS = { + ('ub', 'L', 16, None, 'token', '45337813-1'): + 'gs://ub-checkpoints/45337813-cityscapes_segmenter_pretrained/1', +} + + +CHECKPOINT_PATH = CHECKPOINT_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE, + RESNET_SIZE, CLASSIFIER, EXPERIMENTID)] + +if VIT_SIZE == 'B': + mlp_dim = 3072 + num_heads = 12 + num_layers = 12 + hidden_size = 768 +elif VIT_SIZE == 'L': + mlp_dim = 4096 + num_heads = 16 + num_layers = 24 + hidden_size = 1024 + + +def get_config(runlocal=''): + """Returns the configuration for Cityscapes segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'cityscapes_segmenter_eval' + + # Dataset. + config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = (1024, 2048) + config.dataset_configs.train_split = 'train' + config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + + # Model. 
+ config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.0 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = CLASSIFIER + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'linear' + + # Training. + config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(100) + config.batch_size = 64 + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. + config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE // config.get_ref( + 'batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch') + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 1e-4 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # Logging. + config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. 
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. + config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. + config.eval_mode = True + config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'segmm' + config.eval_configs.window_stride = 512 + config.eval_configs.store_logits = False + config.model.input_shape = target_size + + # Eval parameters for robustness + config.eval_label_shift = True + config.eval_covariate_shift = True + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'nmlogit' + config.eval_robustness_configs.num_top_k = 1 + + # Load checkpoint + config.checkpoint_configs = ml_collections.ConfigDict() + config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN + config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH + config.checkpoint_configs.classifier = 'token' + + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. 
+ + if runlocal: + config.count_flops = False + config.target_size = (128, 128) + config.batch_size = 8 + config.num_training_epochs = 5 + config.warmup_steps = 0 + config.dataset_configs.train_split = 'train[:5%]' + config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE_SPLIT // config.get_ref( + 'batch_size') + + return config + + +def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task): + """Defines checkpoints for sweep.""" + overwrites = [] + if resnet_size is not None: + raise NotImplementedError('') + else: + overwrites.append( + hyper.sweep('config.model.patches', [{'size': (stride, stride)}])) + + if vit_size == 'B': + overwrites.append( + hyper.sweep('config.model.backbone.mlp_dim', [3072])) + overwrites.append( + hyper.sweep('config.model.backbone.num_heads', [12])) + overwrites.append( + hyper.sweep('config.model.backbone.num_layers', [12])) + overwrites.append( + hyper.sweep('config.model.backbone.hidden_size', [768])) + elif vit_size == 'L': + overwrites.append( + hyper.sweep('config.model.backbone.mlp_dim', [4096])) + overwrites.append( + hyper.sweep('config.model.backbone.num_heads', [16])) + overwrites.append( + hyper.sweep('config.model.backbone.num_layers', [24])) + overwrites.append( + hyper.sweep('config.model.backbone.hidden_size', [1024])) + else: + raise NotImplementedError('') + + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_format', + [backbone_origin])) + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [ + CHECKPOINT_PATHS[(backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task)] + ])) + + return hyper.product(overwrites) + + +def get_sweep(hyper): + return hyper.product([]) diff --git a/experimental/robust_segvit/configs/cityscapes/deterministic_seeds.py b/experimental/robust_segvit/configs/cityscapes/deterministic_seeds.py new file mode 100644 index 000000000..c3c0de27b --- /dev/null +++ 
b/experimental/robust_segvit/configs/cityscapes/deterministic_seeds.py @@ -0,0 +1,242 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Train segmenter model on cityscapes dataset. + +Compare performance across seeds. + +""" +# pylint: enable=line-too-long + +import ml_collections +import os +import datetime + +_CITYSCAPES_TRAIN_SIZE = 2975 +_CITYSCAPES_TRAIN_SIZE_SPLIT = 146 + +# Model specs. 
+LOAD_PRETRAINED_BACKBONE = True +BACKBONE_ORIGIN = 'vision_transformer' +VIT_SIZE = 'L' +STRIDE = 16 +RESNET_SIZE = None +CLASSIFIER = 'token' +target_size = (768, 768) +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' + + +# Upstream +MODEL_PATHS = { + + # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 + ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'): + 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'): + 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', +} + + +MODEL_PATH = MODEL_PATHS[(BACKBONE_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE, + CLASSIFIER, UPSTREAM_TASK)] + +if VIT_SIZE == 'B': + mlp_dim = 3072 + num_heads = 12 + num_layers = 12 + hidden_size = 768 +elif VIT_SIZE == 'L': + mlp_dim = 4096 + num_heads = 16 + num_layers = 24 + hidden_size = 1024 + + +def get_config(runlocal=''): + """Returns the configuration for Cityscapes segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'cityscapes_segmenter_seeds' + + # Dataset. + config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = target_size + config.dataset_configs.train_split = 'train' + config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + + # Model. 
+ config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.1 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = CLASSIFIER + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'linear' + + # Training. + config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(100) + config.batch_size = 64 + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. + config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE // config.get_ref( + 'batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. 
+ config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch') + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 1e-4 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # load pretrained backbone + config.load_pretrained_backbone = LOAD_PRETRAINED_BACKBONE + config.pretrained_backbone_configs = ml_collections.ConfigDict() + config.pretrained_backbone_configs.checkpoint_format = BACKBONE_ORIGIN + config.pretrained_backbone_configs.checkpoint_path = MODEL_PATH + config.pretrained_backbone_configs.token_init = True + config.pretrained_backbone_configs.classifier = 'token' + config.pretrained_backbone_configs.backbone_type = 'vit' + + # Logging. + config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. + config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. + config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. + config.eval_mode = False + config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'standard' + config.eval_covariate_shift = True + config.eval_label_shift = True + config.model.input_shape = target_size + + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'mlogit' + + # wandb.ai configurations. 
+ config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. + + if runlocal: + config.count_flops = False + config.target_size = (128, 128) + config.batch_size = 8 + config.num_training_epochs = 5 + config.warmup_steps = 0 + config.dataset_configs.train_split = 'train[:5%]' + config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE_SPLIT // config.get_ref( + 'batch_size') + + return config + + +def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task): + """Defines checkpoints for sweep.""" + overwrites = [] + if resnet_size is not None: + raise NotImplementedError('') + else: + overwrites.append( + hyper.sweep('config.model.patches', [{'size': (stride, stride)}])) + + if vit_size == 'B': + overwrites.append( + hyper.sweep('config.model.backbone.mlp_dim', [3072])) + overwrites.append( + hyper.sweep('config.model.backbone.num_heads', [12])) + overwrites.append( + hyper.sweep('config.model.backbone.num_layers', [12])) + overwrites.append( + hyper.sweep('config.model.backbone.hidden_size', [768])) + elif vit_size == 'L': + overwrites.append( + hyper.sweep('config.model.backbone.mlp_dim', [4096])) + overwrites.append( + hyper.sweep('config.model.backbone.num_heads', [16])) + overwrites.append( + hyper.sweep('config.model.backbone.num_layers', [24])) + overwrites.append( + hyper.sweep('config.model.backbone.hidden_size', [1024])) + else: + raise NotImplementedError('') + + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_format', + [backbone_origin])) + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [ + MODEL_PATHS[(backbone_origin, 
vit_size, stride, resnet_size, + classifier, upstream_task)] + ])) + + return hyper.product(overwrites) + + +def get_sweep(hyper): + """Defines the hyper-parameters sweeps for doing grid search.""" + + seeds = hyper.sweep('config.rng_seed', range(0, 5)) + + + return hyper.product([seeds]) + diff --git a/experimental/robust_segvit/configs/cityscapes/gp.py b/experimental/robust_segvit/configs/cityscapes/gp.py index 29205f705..e9ac3c654 100644 --- a/experimental/robust_segvit/configs/cityscapes/gp.py +++ b/experimental/robust_segvit/configs/cityscapes/gp.py @@ -22,27 +22,31 @@ # pylint: enable=line-too-long import ml_collections +import os +import datetime _CITYSCAPES_TRAIN_SIZE = 2975 _CITYSCAPES_TRAIN_SIZE_SPLIT = 146 # Model specs. LOAD_PRETRAINED_BACKBONE = True -BACKBONE_ORIGIN = 'big_vision' +BACKBONE_ORIGIN = 'vision_transformer' VIT_SIZE = 'L' STRIDE = 16 RESNET_SIZE = None CLASSIFIER = 'token' target_size = (768, 768) -UPSTREAM_TASK = 'i21k+imagenet2012' +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' # Upstream MODEL_PATHS = { # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 - ('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'): - 'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'): + 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'): + 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', } @@ -70,11 +74,15 @@ def get_config(runlocal=''): config.experiment_name = 'cityscapes_segmenter_gp_hyper' # Dataset. 
- config.dataset_name = 'cityscapes' + config.dataset_name = 'robust_segvit_segmentation' config.dataset_configs = ml_collections.ConfigDict() config.dataset_configs.target_size = target_size config.dataset_configs.train_split = 'train' - config.dataset_configs.dataset_name = '' # name of ood dataset to evaluate + config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 # Model. config.model_name = 'segvit' @@ -164,6 +172,22 @@ def get_config(runlocal=''): config.eval_configs.mode = 'standard' config.eval_covariate_shift = True config.eval_label_shift = True + config.model.input_shape = target_size + + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'msp' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. if runlocal: config.count_flops = False diff --git a/experimental/robust_segvit/configs/cityscapes/gp_eval.py b/experimental/robust_segvit/configs/cityscapes/gp_eval.py new file mode 100644 index 000000000..d5836360f --- /dev/null +++ b/experimental/robust_segvit/configs/cityscapes/gp_eval.py @@ -0,0 +1,245 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Evaluate segmenter_gp model on cityscapes dataset. + +""" +# pylint: enable=line-too-long + +import ml_collections +import os +import datetime + +_CITYSCAPES_TRAIN_SIZE = 2975 +_CITYSCAPES_TRAIN_SIZE_SPLIT = 146 + +# Model specs. +VIT_SIZE = 'L' +STRIDE = 16 +RESNET_SIZE = None +CLASSIFIER = 'token' +target_size = (768, 768) + +CHECKPOINT_ORIGIN = 'ub' +EXPERIMENTID = '45338722-1' + +# Upstream +CHECKPOINT_PATHS = { + ('ub', 'L', 16, None, 'token', '45338722-1'): + 'gs://ub-checkpoints/45338722-cityscapes_segmenter_gp_hyper/1', +} + + +CHECKPOINT_PATH = CHECKPOINT_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE, + RESNET_SIZE, CLASSIFIER, EXPERIMENTID)] + + +if VIT_SIZE == 'B': + mlp_dim = 3072 + num_heads = 12 + num_layers = 12 + hidden_size = 768 +elif VIT_SIZE == 'L': + mlp_dim = 4096 + num_heads = 16 + num_layers = 24 + hidden_size = 1024 + + +def get_config(runlocal=''): + """Returns the configuration for Cityscapes segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'cityscapes_segmenter_gp_eval' + + # Dataset. 
+ config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = (1024, 2048) + config.dataset_configs.train_split = 'train' + config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + + # Model. + config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.0 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = CLASSIFIER + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'gp' + + # GP layer params + config.model.decoder.gp_layer = ml_collections.ConfigDict() + config.model.decoder.gp_layer.covmat_kwargs = ml_collections.ConfigDict() + config.model.decoder.gp_layer.covmat_kwargs.ridge_penalty = 1. + # Momentum for the covariance update is enabled here (0.99). + # Disable momentum to use the exact covariance update (e.g. for finetuning). + config.model.decoder.gp_layer.covmat_kwargs.momentum = 0.99 + config.model.decoder.mean_field_factor = 1. + # Additional params + config.model.decoder.gp_layer.normalize_input = True + config.model.decoder.gp_layer.hidden_kwargs = ml_collections.ConfigDict() + config.model.decoder.gp_layer.hidden_kwargs.feature_scale = 1. + + # Training. 
+ config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(100) + config.batch_size = 64 + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. + config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE // config.get_ref( + 'batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch') + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 1e-4 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # Logging. + config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. + config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. + config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. 
+ config.eval_mode = True + config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'segmm' + config.eval_configs.window_stride = 512 + config.eval_configs.store_logits = False + config.model.input_shape = target_size + + # Eval parameters for robustness + config.eval_label_shift = True + config.eval_covariate_shift = True + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'nmlogit' + config.eval_robustness_configs.num_top_k = 1 + + # Load checkpoint + config.checkpoint_configs = ml_collections.ConfigDict() + config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN + config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH + config.checkpoint_configs.classifier = 'token' + + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. 
+ + if runlocal: + config.count_flops = False + config.target_size = (128, 128) + config.batch_size = 8 + config.num_training_epochs = 5 + config.warmup_steps = 0 + config.dataset_configs.train_split = 'train[:5%]' + config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE_SPLIT // config.get_ref( + 'batch_size') + + return config + + +def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task): + """Defines checkpoints for sweep.""" + overwrites = [] + if resnet_size is not None: + raise NotImplementedError('') + else: + overwrites.append( + hyper.sweep('config.model.patches', [{'size': (stride, stride)}])) + + if vit_size == 'B': + overwrites.append( + hyper.sweep('config.model.backbone.mlp_dim', [3072])) + overwrites.append( + hyper.sweep('config.model.backbone.num_heads', [12])) + overwrites.append( + hyper.sweep('config.model.backbone.num_layers', [12])) + overwrites.append( + hyper.sweep('config.model.backbone.hidden_size', [768])) + elif vit_size == 'L': + overwrites.append( + hyper.sweep('config.model.backbone.mlp_dim', [4096])) + overwrites.append( + hyper.sweep('config.model.backbone.num_heads', [16])) + overwrites.append( + hyper.sweep('config.model.backbone.num_layers', [24])) + overwrites.append( + hyper.sweep('config.model.backbone.hidden_size', [1024])) + else: + raise NotImplementedError('') + + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_format', + [backbone_origin])) + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [ + CHECKPOINT_PATHS[(backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task)] + ])) + + return hyper.product(overwrites) + + +def get_sweep(hyper): + return hyper.product([]) diff --git a/experimental/robust_segvit/configs/cityscapes/gp_seeds.py b/experimental/robust_segvit/configs/cityscapes/gp_seeds.py new file mode 100644 index 000000000..f43d397b0 --- /dev/null +++ 
b/experimental/robust_segvit/configs/cityscapes/gp_seeds.py @@ -0,0 +1,260 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Train segmenter_gp model on cityscapes dataset. + +Compare performance across seeds. + +""" +# pylint: enable=line-too-long + +import ml_collections +import os +import datetime + +_CITYSCAPES_TRAIN_SIZE = 2975 +_CITYSCAPES_TRAIN_SIZE_SPLIT = 146 + +# Model specs. 
+LOAD_PRETRAINED_BACKBONE = True +BACKBONE_ORIGIN = 'vision_transformer' +VIT_SIZE = 'L' +STRIDE = 16 +RESNET_SIZE = None +CLASSIFIER = 'token' +target_size = (768, 768) +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' + + +# Upstream +MODEL_PATHS = { + + # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 + ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'): + 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'): + 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', +} + + +MODEL_PATH = MODEL_PATHS[(BACKBONE_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE, + CLASSIFIER, UPSTREAM_TASK)] + +if VIT_SIZE == 'B': + mlp_dim = 3072 + num_heads = 12 + num_layers = 12 + hidden_size = 768 +elif VIT_SIZE == 'L': + mlp_dim = 4096 + num_heads = 16 + num_layers = 24 + hidden_size = 1024 + + +def get_config(runlocal=''): + """Returns the configuration for Cityscapes segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'cityscapes_segmenter_gp_seeds' + + # Dataset. + config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = target_size + config.dataset_configs.train_split = 'train' + config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + + # Model. 
+ config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.1 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = CLASSIFIER + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'gp' + + # GP layer params + config.model.decoder.gp_layer = ml_collections.ConfigDict() + config.model.decoder.gp_layer.covmat_kwargs = ml_collections.ConfigDict() + config.model.decoder.gp_layer.covmat_kwargs.ridge_penalty = 1. + # Momentum for the covariance update is enabled here (0.99). + # Disable momentum to use the exact covariance update (e.g. for finetuning). + config.model.decoder.gp_layer.covmat_kwargs.momentum = 0.99 + config.model.decoder.mean_field_factor = 3. + # Additional params + config.model.decoder.gp_layer.normalize_input = True + config.model.decoder.gp_layer.hidden_kwargs = ml_collections.ConfigDict() + config.model.decoder.gp_layer.hidden_kwargs.feature_scale = 1. + + # Training. + config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(100) + config.batch_size = 64 + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. + config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE // config.get_ref( + 'batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. 
+ config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch') + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 1e-4 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # load pretrained backbone + config.load_pretrained_backbone = LOAD_PRETRAINED_BACKBONE + config.pretrained_backbone_configs = ml_collections.ConfigDict() + config.pretrained_backbone_configs.checkpoint_format = BACKBONE_ORIGIN + config.pretrained_backbone_configs.checkpoint_path = MODEL_PATH + config.pretrained_backbone_configs.token_init = True + config.pretrained_backbone_configs.classifier = 'token' + config.pretrained_backbone_configs.backbone_type = 'vit' + + # Logging. + config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. + config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. + config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. + config.eval_mode = False + config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'standard' + config.eval_covariate_shift = True + config.eval_label_shift = True + config.model.input_shape = target_size + + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'mlogit' + + # wandb.ai configurations. 
+ config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. + + if runlocal: + config.count_flops = False + config.target_size = (128, 128) + config.batch_size = 8 + config.num_training_epochs = 5 + config.warmup_steps = 0 + config.dataset_configs.train_split = 'train[:5%]' + config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE_SPLIT // config.get_ref( + 'batch_size') + + return config + + +def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task): + """Defines checkpoints for sweep.""" + overwrites = [] + if resnet_size is not None: + raise NotImplementedError('') + else: + overwrites.append( + hyper.sweep('config.model.patches', [{'size': (stride, stride)}])) + + if vit_size == 'B': + overwrites.append( + hyper.sweep('config.model.backbone.mlp_dim', [3072])) + overwrites.append( + hyper.sweep('config.model.backbone.num_heads', [12])) + overwrites.append( + hyper.sweep('config.model.backbone.num_layers', [12])) + overwrites.append( + hyper.sweep('config.model.backbone.hidden_size', [768])) + elif vit_size == 'L': + overwrites.append( + hyper.sweep('config.model.backbone.mlp_dim', [4096])) + overwrites.append( + hyper.sweep('config.model.backbone.num_heads', [16])) + overwrites.append( + hyper.sweep('config.model.backbone.num_layers', [24])) + overwrites.append( + hyper.sweep('config.model.backbone.hidden_size', [1024])) + else: + raise NotImplementedError('') + + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_format', + [backbone_origin])) + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [ + MODEL_PATHS[(backbone_origin, 
vit_size, stride, resnet_size, + classifier, upstream_task)] + ])) + + return hyper.product(overwrites) + + +def get_sweep(hyper): + """Defines the hyper-parameters sweeps for doing grid search.""" + + parameters = [ + hyper.sweep('config.model.decoder.gp_layer.normalize_input', + [True, False]), + hyper.sweep('config.model.decoder.mean_field_factor', + hyper.discrete(range(1, 10))), + hyper.sweep('config.model.decoder.gp_layer.hidden_kwargs.feature_scale', + [1.0, 2.0]), + ] + + return hyper.product(parameters) diff --git a/experimental/robust_segvit/configs/cityscapes/het.py b/experimental/robust_segvit/configs/cityscapes/het.py index 45e263afe..7d6b8f552 100644 --- a/experimental/robust_segvit/configs/cityscapes/het.py +++ b/experimental/robust_segvit/configs/cityscapes/het.py @@ -22,27 +22,31 @@ # pylint: enable=line-too-long import ml_collections +import os +import datetime _CITYSCAPES_TRAIN_SIZE = 2975 _CITYSCAPES_TRAIN_SIZE_SPLIT = 146 # Model specs. LOAD_PRETRAINED_BACKBONE = True -BACKBONE_ORIGIN = 'big_vision' +BACKBONE_ORIGIN = 'vision_transformer' VIT_SIZE = 'L' STRIDE = 16 RESNET_SIZE = None CLASSIFIER = 'token' target_size = (768, 768) -UPSTREAM_TASK = 'i21k+imagenet2012' +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' # Upstream MODEL_PATHS = { # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 - ('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'): - 'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'): + 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'): + 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', } @@ -70,11 +74,15 @@ def get_config(runlocal=''): config.experiment_name = 'cityscapes_segmenter_het_base' # Dataset. 
- config.dataset_name = 'cityscapes' + config.dataset_name = 'robust_segvit_segmentation' config.dataset_configs = ml_collections.ConfigDict() config.dataset_configs.target_size = target_size config.dataset_configs.train_split = 'train' - config.dataset_configs.dataset_name = '' # name of ood dataset to evaluate + config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 # Model. config.model_name = 'segvit' @@ -165,6 +173,22 @@ def get_config(runlocal=''): config.eval_configs.mode = 'standard' config.eval_covariate_shift = True config.eval_label_shift = True + config.model.input_shape = target_size + + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'msp' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. if runlocal: config.count_flops = False diff --git a/experimental/robust_segvit/configs/cityscapes/het_eval.py b/experimental/robust_segvit/configs/cityscapes/het_eval.py new file mode 100644 index 000000000..4038e5688 --- /dev/null +++ b/experimental/robust_segvit/configs/cityscapes/het_eval.py @@ -0,0 +1,249 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Evaluate segmenter model on cityscapes dataset. + +Compare performance from deterministic upstream checkpoints. + +""" +# pylint: enable=line-too-long + +import ml_collections +import os +import datetime + +_CITYSCAPES_TRAIN_SIZE = 2975 +_CITYSCAPES_TRAIN_SIZE_SPLIT = 146 + +# Model specs. +VIT_SIZE = 'L' +STRIDE = 16 +RESNET_SIZE = None +CLASSIFIER = 'token' +target_size = (768, 768) + + +CHECKPOINT_ORIGIN = 'ub' +EXPERIMENTID = '45338794-1' + +# Upstream +CHECKPOINT_PATHS = { + ('ub', 'L', 16, None, 'token', '45338794-1'): + 'gs://ub-checkpoints/45338794-cityscapes_segmenter_het_base/1', +} + + +CHECKPOINT_PATH = CHECKPOINT_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE, + RESNET_SIZE, CLASSIFIER, EXPERIMENTID)] + +if VIT_SIZE == 'B': + mlp_dim = 3072 + num_heads = 12 + num_layers = 12 + hidden_size = 768 +elif VIT_SIZE == 'L': + mlp_dim = 4096 + num_heads = 16 + num_layers = 24 + hidden_size = 1024 + + +def get_config(runlocal=''): + """Returns the configuration for Cityscapes segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'cityscapes_segmenter_het_eval' + + # Dataset. 
+ config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = (1024, 2048) + config.dataset_configs.train_split = 'train' + config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + + # Model. + config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.0 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = CLASSIFIER + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'het' + + # Het layer params + # temp: wide sweep [0.15, 0.3, 0.5, 0.75, 1.0, 1.5, 2.0] + config.model.decoder.temperature = 1.0 + # efficient low rank approx ~ FxK where K is the classes. False for K<20. + config.model.decoder.param_efficient = False + # F as a low rank approx of KxK matrix has num_factors: + # imagenet~15, jft~50, cifar~6, cityscapes~sweep(5-10). + config.model.decoder.num_factors = 5 + # mc_samples: use as much as can be afforded, ideally > 10. + config.model.decoder.mc_samples = 1000 + config.model.decoder.return_locs = False + # turn on to run an approx on KHW x KHW instead of KxK. + config.model.decoder.share_samples_across_batch = False + + # Training. 
+ config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(100) + config.batch_size = 64 + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. + config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE // config.get_ref( + 'batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch') + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 1e-4 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # Logging. + config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. + config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. + config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. 
+ config.eval_mode = True + config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'segmm' + config.eval_configs.window_stride = 512 + config.eval_configs.store_logits = False + config.model.input_shape = target_size + + # Eval parameters for robustness + config.eval_label_shift = True + config.eval_covariate_shift = True + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'nmlogit' + config.eval_robustness_configs.num_top_k = 1 + + # Load checkpoint + config.checkpoint_configs = ml_collections.ConfigDict() + config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN + config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH + config.checkpoint_configs.classifier = 'token' + + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. 
+ + if runlocal: + config.count_flops = False + config.target_size = (128, 128) + config.batch_size = 8 + config.num_training_epochs = 5 + config.warmup_steps = 0 + config.dataset_configs.train_split = 'train[:5%]' + config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE_SPLIT // config.get_ref( + 'batch_size') + + return config + + +def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task): + """Defines checkpoints for sweep.""" + overwrites = [] + if resnet_size is not None: + raise NotImplementedError('') + else: + overwrites.append( + hyper.sweep('config.model.patches', [{'size': (stride, stride)}])) + + if vit_size == 'B': + overwrites.append( + hyper.sweep('config.model.backbone.mlp_dim', [3072])) + overwrites.append( + hyper.sweep('config.model.backbone.num_heads', [12])) + overwrites.append( + hyper.sweep('config.model.backbone.num_layers', [12])) + overwrites.append( + hyper.sweep('config.model.backbone.hidden_size', [768])) + elif vit_size == 'L': + overwrites.append( + hyper.sweep('config.model.backbone.mlp_dim', [4096])) + overwrites.append( + hyper.sweep('config.model.backbone.num_heads', [16])) + overwrites.append( + hyper.sweep('config.model.backbone.num_layers', [24])) + overwrites.append( + hyper.sweep('config.model.backbone.hidden_size', [1024])) + else: + raise NotImplementedError('') + + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_format', + [backbone_origin])) + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [ + CHECKPOINT_PATHS[(backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task)] + ])) + + return hyper.product(overwrites) + + +def get_sweep(hyper): + return hyper.product([]) + diff --git a/experimental/robust_segvit/configs/cityscapes/torch_eval.py b/experimental/robust_segvit/configs/cityscapes/torch_eval.py new file mode 100644 index 000000000..532e9fe21 --- /dev/null +++ 
b/experimental/robust_segvit/configs/cityscapes/torch_eval.py @@ -0,0 +1,213 @@ +import ml_collections +import os +import datetime + +_CITYSCAPES_TRAIN_SIZE = 2975 +_CITYSCAPES_TRAIN_SIZE_SPLIT = 146 + +# Model specs. +CHECKPOINT_ORIGIN = 'torch-segmm' +VIT_SIZE = 'L' +STRIDE = 16 +RESNET_SIZE = None +CLASSIFIER = 'token' +target_size = (768, 768) +EXPERIMENTID = 'torch-segmm-1' + +# Upstream +CHECKPOINT_PATHS = { + ('torch-segmm', 'L', 16, None, 'token', 'torch-segmm-1'): + 'gs://ub-ekb/seg_l16_linear/checkpoint_model.npy', +} + + +CHECKPOINT_PATH = CHECKPOINT_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE, + RESNET_SIZE, CLASSIFIER, EXPERIMENTID)] + +if VIT_SIZE == 'B': + mlp_dim = 3072 + num_heads = 12 + num_layers = 12 + hidden_size = 768 +elif VIT_SIZE == 'L': + mlp_dim = 4096 + num_heads = 16 + num_layers = 24 + hidden_size = 1024 + + +def get_config(runlocal=''): + """Returns the configuration for Cityscapes segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'cityscapes_segmenter_torch_eval' + + # Dataset. + config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = (1024, 2048) + config.dataset_configs.train_split = 'train' + config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + + # Model. 
+ config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.0 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = CLASSIFIER + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'linear' + + # Training. + config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(100) + config.batch_size = 64 + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. + config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE // config.get_ref( + 'batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch') + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 1e-4 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # Logging. + config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. 
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. + config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. + config.eval_mode = True + config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'segmm' + config.eval_configs.window_stride = 512 + config.model.input_shape = target_size + + # Eval parameters for robustness + config.eval_label_shift = True + config.eval_covariate_shift = True + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'nmlogit' + config.eval_robustness_configs.num_top_k = 1 + + # Load checkpoint + config.checkpoint_configs = ml_collections.ConfigDict() + config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN + config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH + config.checkpoint_configs.classifier = 'token' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. 
+ + if runlocal: + config.count_flops = False + config.target_size = (128, 128) + config.batch_size = 8 + config.num_training_epochs = 5 + config.warmup_steps = 0 + config.dataset_configs.train_split = 'train[:5%]' + config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE_SPLIT // config.get_ref( + 'batch_size') + + return config + + +def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task): + """Defines checkpoints for sweep.""" + overwrites = [] + if resnet_size is not None: + raise NotImplementedError('') + else: + overwrites.append( + hyper.sweep('config.model.patches', [{'size': (stride, stride)}])) + + if vit_size == 'B': + overwrites.append( + hyper.sweep('config.model.backbone.mlp_dim', [3072])) + overwrites.append( + hyper.sweep('config.model.backbone.num_heads', [12])) + overwrites.append( + hyper.sweep('config.model.backbone.num_layers', [12])) + overwrites.append( + hyper.sweep('config.model.backbone.hidden_size', [768])) + elif vit_size == 'L': + overwrites.append( + hyper.sweep('config.model.backbone.mlp_dim', [4096])) + overwrites.append( + hyper.sweep('config.model.backbone.num_heads', [16])) + overwrites.append( + hyper.sweep('config.model.backbone.num_layers', [24])) + overwrites.append( + hyper.sweep('config.model.backbone.hidden_size', [1024])) + else: + raise NotImplementedError('') + + overwrites.append( + hyper.sweep('config.checkpoint_configs.checkpoint_format', + [backbone_origin])) + overwrites.append( + hyper.sweep('config.checkpoint_configs.checkpoint_path', [ + CHECKPOINT_PATHS[(backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task)] + ])) + + return hyper.product(overwrites) + + +def get_sweep(hyper): + """Defines the parameters used to compare multiple metrics during eval.""" + + checkpoints = hyper.chainit([ + checkpoint(hyper, 'ub', 'L', 16, None, 'token', 'torch-segmm-1'), + ]) + + return hyper.product([checkpoints]) diff --git 
a/experimental/robust_segvit/configs/cityscapes/toy_model.py b/experimental/robust_segvit/configs/cityscapes/toy_model.py index 6a5e50e80..ba2147c5d 100644 --- a/experimental/robust_segvit/configs/cityscapes/toy_model.py +++ b/experimental/robust_segvit/configs/cityscapes/toy_model.py @@ -20,7 +20,8 @@ # pylint: enable=line-too-long import ml_collections - +import os +import datetime batch_size = 128 _CITYSCAPES_TRAIN_SIZE_SPLIT = 146 @@ -43,11 +44,15 @@ def get_config(runlocal=''): config.experiment_name = 'cityscapes_segmenter_toy_model' # Dataset. - config.dataset_name = 'cityscapes' + config.dataset_name = 'robust_segvit_segmentation' config.dataset_configs = ml_collections.ConfigDict() config.dataset_configs.target_size = target_size config.dataset_configs.train_split = 'train[:5%]' - config.dataset_configs.dataset_name = '' # name of ood dataset to evaluate + config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 # Model. config.model_name = 'segvit' @@ -118,6 +123,21 @@ def get_config(runlocal=''): config.eval_covariate_shift = True config.eval_label_shift = True + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'mlogit' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. 
+ if runlocal: config.count_flops = False diff --git a/experimental/robust_segvit/configs/cityscapes/toy_model_eval.py b/experimental/robust_segvit/configs/cityscapes/toy_model_eval.py new file mode 100644 index 000000000..83965106e --- /dev/null +++ b/experimental/robust_segvit/configs/cityscapes/toy_model_eval.py @@ -0,0 +1,165 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Eval toy segmenter model on cityscapes. + +""" +# pylint: enable=line-too-long + +import ml_collections +import os +import datetime + +batch_size = 8 +_CITYSCAPES_TRAIN_SIZE_SPLIT = 16 + +# Model spec. +STRIDE = 4 +mlp_dim = 2 +num_heads = 1 +num_layers = 1 +hidden_size = 1 +target_size = (128, 128) + +# Upstream +CHECKPOINT_ORIGIN = 'ub' +VIT_SIZE = 'debug' +RESNET_SIZE = None +CLASSIFIER = 'token' +EXPERIMENTID = 'city_toy' + +CHECKPOINT_PATHS = { + ('ub', 'debug', 4, None, 'token', 'city_toy'): + 'ub-ekb/segmenter/cityscapes/toy_model/toy_model', +} + + +CHECKPOINT_PATH = CHECKPOINT_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE, + RESNET_SIZE, CLASSIFIER, EXPERIMENTID)] + +def get_config(runlocal=''): + """Returns the configuration for Cityscapes segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'cityscapes_segmenter_toy_model' + + # Dataset. 
+ config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = target_size + config.dataset_configs.train_split = 'train[:16]' + config.dataset_configs.validation_split = 'validation[:16]' + config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + + # Model. + config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.1 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = 'gap' + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'linear' + + # Training. + config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(2) + config.batch_size = batch_size + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. + config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE_SPLIT // config.get_ref( + 'batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. 
+ config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 0 + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 1e-4 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # init not included + + # Logging. + config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. + config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. + config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. + config.eval_mode = True + config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'standard' + config.eval_covariate_shift = True + config.eval_label_shift = True + config.eval_configs.store_logits = False + + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'mlogit' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. 
+ + if runlocal: + config.count_flops = False + + return config + + +def get_sweep(hyper): + return hyper.product([]) diff --git a/experimental/robust_segvit/configs/street_hazards/be.py b/experimental/robust_segvit/configs/street_hazards/be.py new file mode 100644 index 000000000..2f6b9f7a9 --- /dev/null +++ b/experimental/robust_segvit/configs/street_hazards/be.py @@ -0,0 +1,265 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Train segmenter model on street_hazards. + +Compare performance from deterministic upstream checkpoints. + +""" +# pylint: enable=line-too-long + +import ml_collections +import os +import datetime + +_CITYSCAPES_FINE_TRAIN_SIZE = 2975 +_CITYSCAPES_COARSE_TRAIN_SIZE = 19998 + +_ADE20K_TRAIN_SIZE = 20210 +_PASCAL_VOC_TRAIN_SIZE = 10582 +_PASCAL_CONTEXT_TRAIN_SIZE = 4998 +_STREET_HAZARDS_TRAIN_SIZE = 5125 + +TRAIN_SIZES = { + 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE, + 'ade20k': _ADE20K_TRAIN_SIZE, + 'ade20k_ind': _ADE20K_TRAIN_SIZE, + 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE, + 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE, + 'street_hazards': _STREET_HAZARDS_TRAIN_SIZE, + +} + +# Model specs. 
+LOAD_PRETRAINED_BACKBONE = True +BACKBONE_ORIGIN = 'vision_transformer' +VIT_SIZE = 'L' +STRIDE = 16 +RESNET_SIZE = None +CLASSIFIER = 'token' +target_size = (720, 720) +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' + + +# Upstream +MODEL_PATHS = { + + # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 + ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'): + 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'): + 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', # pylint: disable=g-long-lambda + +} + + +MODEL_PATH = MODEL_PATHS[(BACKBONE_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE, + CLASSIFIER, UPSTREAM_TASK)] + +if VIT_SIZE == 'B': + mlp_dim = 3072 + num_heads = 12 + num_layers = 12 + hidden_size = 768 +elif VIT_SIZE == 'L': + mlp_dim = 4096 + num_heads = 16 + num_layers = 24 + hidden_size = 1024 + +TRAIN_SAMPLES = 32 + + +def get_config(runlocal=''): + """Returns the configuration for street hazards segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'street_hazards_segmenter_be' + + # Dataset. + config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = target_size + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + + config.dataset_configs.train_split = 'train' + config.dataset_configs.name = 'street_hazards' + config.dataset_configs.dataset_name = '' # ood name flag to write in eval. + + # Model. 
+ config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit_be' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.1 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = CLASSIFIER + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'linear_be' + + # BE variables + config.model.backbone.ens_size = 3 + config.model.backbone.random_sign_init = -0.5 + config.model.backbone.be_layers = (22, 23) + config.fast_weight_lr_multiplier = 1.0 + + # Training. + config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(100) + config.batch_size = 32 + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. + config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name) + config.steps_per_epoch = config.get_ref( + 'num_train_examples') // config.get_ref('batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. 
+ config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch') + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 1e-4 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # load pretrained backbone + config.load_pretrained_backbone = LOAD_PRETRAINED_BACKBONE + config.pretrained_backbone_configs = ml_collections.ConfigDict() + config.pretrained_backbone_configs.checkpoint_format = BACKBONE_ORIGIN + config.pretrained_backbone_configs.checkpoint_path = MODEL_PATH + config.pretrained_backbone_configs.token_init = True + config.pretrained_backbone_configs.classifier = 'token' + config.pretrained_backbone_configs.backbone_type = 'vit' + + # Logging. + config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. + config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. + config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. + config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'standard' + config.eval_mode = False + config.eval_covariate_shift = True + config.eval_label_shift = True + config.model.input_shape = target_size + + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'msp' + + # wandb.ai configurations. 
+ config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. + + if runlocal: + config.count_flops = False + config.dataset_configs.train_target_size = (128, 128) + config.model.input_shape = config.dataset_configs.train_target_size + config.batch_size = 8 + config.num_training_epochs = 5 + config.warmup_steps = 0 + config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]' + config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]' + config.num_train_examples = TRAIN_SAMPLES + + return config + + +def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task): + """Defines checkpoints for sweep.""" + overwrites = [] + if resnet_size is not None: + raise NotImplementedError('') + else: + overwrites.append( + hyper.sweep('config.model.patches', [{ + 'size': (stride, stride) + }])) + + if vit_size == 'B': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768])) + elif vit_size == 'L': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024])) + else: + raise NotImplementedError('') + + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_format', + [backbone_origin])) + 
overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [ + MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task)] + ])) + + return hyper.product(overwrites) + + +def get_sweep(hyper): + """Defines the hyper-parameters sweeps for grid search.""" + + random_sign_init = hyper.sweep('config.model.backbone.random_sign_init', + [-0.5, 0.5]) + fast_weight_lr_multiplier = hyper.sweep('config.fast_weight_lr_multiplier', + [0.5, 1.0, 2.0]) + + return hyper.product([random_sign_init, fast_weight_lr_multiplier]) diff --git a/experimental/robust_segvit/configs/street_hazards/deterministic.py b/experimental/robust_segvit/configs/street_hazards/deterministic.py new file mode 100644 index 000000000..0c044eae9 --- /dev/null +++ b/experimental/robust_segvit/configs/street_hazards/deterministic.py @@ -0,0 +1,257 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Train segmenter model on street_hazards. + +Compare performance from deterministic upstream checkpoints. 
+ +""" +# pylint: enable=line-too-long + +import ml_collections +import os +import datetime + +_CITYSCAPES_FINE_TRAIN_SIZE = 2975 +_CITYSCAPES_COARSE_TRAIN_SIZE = 19998 + +_ADE20K_TRAIN_SIZE = 20210 +_PASCAL_VOC_TRAIN_SIZE = 10582 +_PASCAL_CONTEXT_TRAIN_SIZE = 4998 +_STREET_HAZARDS_TRAIN_SIZE = 5125 + +TRAIN_SIZES = { + 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE, + 'ade20k': _ADE20K_TRAIN_SIZE, + 'ade20k_ind': _ADE20K_TRAIN_SIZE, + 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE, + 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE, + 'street_hazards': _STREET_HAZARDS_TRAIN_SIZE + +} + +# Model specs. +LOAD_PRETRAINED_BACKBONE = True +BACKBONE_ORIGIN = 'vision_transformer' +VIT_SIZE = 'L' +STRIDE = 16 +RESNET_SIZE = None +CLASSIFIER = 'token' +target_size = (720, 720) +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' + + +# Upstream +MODEL_PATHS = { + # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 with augreg + ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'): + 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'): + 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', +} + + +MODEL_PATH = MODEL_PATHS[(BACKBONE_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE, + CLASSIFIER, UPSTREAM_TASK)] + +if VIT_SIZE == 'B': + mlp_dim = 3072 + num_heads = 12 + num_layers = 12 + hidden_size = 768 +elif VIT_SIZE == 'L': + mlp_dim = 4096 + num_heads = 16 + num_layers = 24 + hidden_size = 1024 + +TRAIN_SAMPLES = 32 + + +def get_config(runlocal=''): + """Returns the configuration for street_hazards segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'street_hazards_deterministic' + + # Dataset. 
+ config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = target_size + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + + config.dataset_configs.train_split = 'train' + config.dataset_configs.name = 'street_hazards' + config.dataset_configs.dataset_name = '' # ood name flag to write in eval. + + # Model. + config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.1 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = CLASSIFIER + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'linear' + + # Training. + config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(100) + config.batch_size = 32 + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. + config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name) + config.steps_per_epoch = config.get_ref( + 'num_train_examples') // config.get_ref('batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. 
+ config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch') + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 1e-4 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # load pretrained backbone + config.load_pretrained_backbone = LOAD_PRETRAINED_BACKBONE + config.pretrained_backbone_configs = ml_collections.ConfigDict() + config.pretrained_backbone_configs.checkpoint_format = BACKBONE_ORIGIN + config.pretrained_backbone_configs.checkpoint_path = MODEL_PATH + config.pretrained_backbone_configs.token_init = True + config.pretrained_backbone_configs.classifier = 'token' + config.pretrained_backbone_configs.backbone_type = 'vit' + + # Logging. + config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. + config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. + config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. + config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'standard' + config.eval_mode = False + config.eval_covariate_shift = True + config.eval_label_shift = True + config.model.input_shape = target_size + + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'msp' + + # wandb.ai configurations. 
+ config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. + + if runlocal: + config.count_flops = False + config.dataset_configs.train_target_size = (128, 128) + config.model.input_shape = config.dataset_configs.train_target_size + config.batch_size = 8 + config.num_training_epochs = 5 + config.warmup_steps = 0 + config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]' + config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]' + config.num_train_examples = TRAIN_SAMPLES + + return config + + +def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task): + """Defines checkpoints for sweep.""" + overwrites = [] + if resnet_size is not None: + raise NotImplementedError('') + else: + overwrites.append( + hyper.sweep('config.model.patches', [{ + 'size': (stride, stride) + }])) + + if vit_size == 'B': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768])) + elif vit_size == 'L': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024])) + else: + raise NotImplementedError('') + + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_format', + [backbone_origin])) + 
overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [ + MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task)] + ])) + + return hyper.product(overwrites) + + +def get_sweep(hyper): + """Defines the hyper-parameters sweeps for doing grid search.""" + + learning_rate = hyper.sweep('config.lr_configs.base_learning_rate', + [1e-4, 3e-4, 3e-5, 1e-5]) + + epochs = hyper.sweep('config.num_training_epochs', [100, 50, 200, 250]) + + return hyper.product([learning_rate, epochs]) diff --git a/experimental/robust_segvit/configs/street_hazards/deterministic_eval.py b/experimental/robust_segvit/configs/street_hazards/deterministic_eval.py new file mode 100644 index 000000000..17e4493a5 --- /dev/null +++ b/experimental/robust_segvit/configs/street_hazards/deterministic_eval.py @@ -0,0 +1,247 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Evaluate segmenter model on street_hazards. 
+ + +""" +# pylint: enable=line-too-long + +import ml_collections +import os +import datetime + +_CITYSCAPES_FINE_TRAIN_SIZE = 2975 +_CITYSCAPES_COARSE_TRAIN_SIZE = 19998 + +_ADE20K_TRAIN_SIZE = 20210 +_PASCAL_VOC_TRAIN_SIZE = 10582 +_PASCAL_CONTEXT_TRAIN_SIZE = 4998 +_STREET_HAZARDS_TRAIN_SIZE = 5125 + +TRAIN_SIZES = { + 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE, + 'ade20k': _ADE20K_TRAIN_SIZE, + 'ade20k_ind': _ADE20K_TRAIN_SIZE, + 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE, + 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE, + 'street_hazards': _STREET_HAZARDS_TRAIN_SIZE + +} + +# Model specs. +target_size = (720, 720) + +LOAD_PRETRAINED_BACKBONE = True +VIT_SIZE = 'L' +STRIDE = 16 +RESNET_SIZE = None +CLASSIFIER = 'token' +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' +target_size = (720, 720) + +CHECKPOINT_ORIGIN = 'ub' +EXPERIMENTID='det_run1' + +# Upstream +MODEL_PATHS = { + ('ub', 'L', 16, None, 'token', 'det_run1'): + 'gs://ub-ekb/segmenter/street_hazards/deterministic/deterministic_2022-09-27-07-32-08', +} + + +MODEL_PATH = MODEL_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE, + RESNET_SIZE, CLASSIFIER, EXPERIMENTID)] + +if VIT_SIZE == 'B': + mlp_dim = 3072 + num_heads = 12 + num_layers = 12 + hidden_size = 768 +elif VIT_SIZE == 'L': + mlp_dim = 4096 + num_heads = 16 + num_layers = 24 + hidden_size = 1024 + +TRAIN_SAMPLES = 32 + + +def get_config(runlocal=''): + """Returns the configuration for street_hazards segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'street_hazards_deterministic_eval' + + # Dataset. 
+ config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = target_size + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + + config.dataset_configs.train_split = 'train' + config.dataset_configs.name = 'street_hazards' + config.dataset_configs.dataset_name = '' # ood name flag to write in eval. + + # Model. + config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.0 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = CLASSIFIER + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'linear' + + # Training. + config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(100) + config.batch_size = 32 + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. + config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name) + config.steps_per_epoch = config.get_ref( + 'num_train_examples') // config.get_ref('batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. 
+ config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch') + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 1e-4 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # Load checkpoint + config.checkpoint_configs = ml_collections.ConfigDict() + config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN + config.checkpoint_configs.checkpoint_path = MODEL_PATH + config.checkpoint_configs.classifier = 'token' + + # Logging. + config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. + config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. + config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. + config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'standard' + config.eval_mode = True + config.eval_covariate_shift = True + config.eval_label_shift = True + config.model.input_shape = target_size + config.eval_configs.store_logits = False + + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'mlogit' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. 
+ config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. + + if runlocal: + config.count_flops = False + config.dataset_configs.train_target_size = (128, 128) + config.model.input_shape = config.dataset_configs.train_target_size + config.batch_size = 8 + config.num_training_epochs = 5 + config.warmup_steps = 0 + config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]' + config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]' + config.num_train_examples = TRAIN_SAMPLES + + return config + + +def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task): + """Defines checkpoints for sweep.""" + overwrites = [] + if resnet_size is not None: + raise NotImplementedError('') + else: + overwrites.append( + hyper.sweep('config.model.patches', [{ + 'size': (stride, stride) + }])) + + if vit_size == 'B': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768])) + elif vit_size == 'L': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024])) + else: + raise NotImplementedError('') + + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_format', + [backbone_origin])) + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [ + MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task)] + ])) + + 
return hyper.product(overwrites) + + +def get_sweep(hyper): + return hyper.product([]) diff --git a/experimental/robust_segvit/configs/street_hazards/gp.py b/experimental/robust_segvit/configs/street_hazards/gp.py new file mode 100644 index 000000000..ba4c8bc4e --- /dev/null +++ b/experimental/robust_segvit/configs/street_hazards/gp.py @@ -0,0 +1,273 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Train segmenter model on street_hazards. + + +""" +# pylint: enable=line-too-long + +import ml_collections +import os +import datetime + +_CITYSCAPES_FINE_TRAIN_SIZE = 2975 +_CITYSCAPES_COARSE_TRAIN_SIZE = 19998 + +_ADE20K_TRAIN_SIZE = 20210 +_PASCAL_VOC_TRAIN_SIZE = 10582 +_PASCAL_CONTEXT_TRAIN_SIZE = 4998 +_STREET_HAZARDS_TRAIN_SIZE = 5125 + +TRAIN_SIZES = { + 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE, + 'ade20k': _ADE20K_TRAIN_SIZE, + 'ade20k_ind': _ADE20K_TRAIN_SIZE, + 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE, + 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE, + 'street_hazards': _STREET_HAZARDS_TRAIN_SIZE + +} + +# Model specs. 
+LOAD_PRETRAINED_BACKBONE = True +BACKBONE_ORIGIN = 'vision_transformer' +VIT_SIZE = 'L' +STRIDE = 16 +RESNET_SIZE = None +CLASSIFIER = 'token' +target_size = (720, 720) +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' + + +# Upstream +MODEL_PATHS = { + # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 with augreg + ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'): + 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'): + 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', +} + + +MODEL_PATH = MODEL_PATHS[(BACKBONE_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE, + CLASSIFIER, UPSTREAM_TASK)] + +if VIT_SIZE == 'B': + mlp_dim = 3072 + num_heads = 12 + num_layers = 12 + hidden_size = 768 +elif VIT_SIZE == 'L': + mlp_dim = 4096 + num_heads = 16 + num_layers = 24 + hidden_size = 1024 + +TRAIN_SAMPLES = 32 + + +def get_config(runlocal=''): + """Returns the configuration for street_hazards segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'street_hazards_gp_hyper' + + # Dataset. + config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = target_size + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + + config.dataset_configs.train_split = 'train' + config.dataset_configs.name = 'street_hazards' + config.dataset_configs.dataset_name = '' # ood name flag to write in eval. + + # Model. 
+ config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.1 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = CLASSIFIER + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'gp' + + # GP layer params + config.model.decoder.gp_layer = ml_collections.ConfigDict() + config.model.decoder.gp_layer.covmat_kwargs = ml_collections.ConfigDict() + config.model.decoder.gp_layer.covmat_kwargs.ridge_penalty = 1. + # Disabling momentum would give the exact covariance update for finetuning. + # NOTE(review): momentum is set to 0.99 below, i.e. not disabled; confirm intent. + config.model.decoder.gp_layer.covmat_kwargs.momentum = 0.99 + config.model.decoder.mean_field_factor = 1. + # Additional params + config.model.decoder.gp_layer.normalize_input = True + config.model.decoder.gp_layer.hidden_kwargs = ml_collections.ConfigDict() + config.model.decoder.gp_layer.hidden_kwargs.feature_scale = 1. + + # Training. + config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(100) + config.batch_size = 32 + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. 
+ config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name) + config.steps_per_epoch = config.get_ref( + 'num_train_examples') // config.get_ref('batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch') + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 1e-4 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # load pretrained backbone + config.load_pretrained_backbone = LOAD_PRETRAINED_BACKBONE + config.pretrained_backbone_configs = ml_collections.ConfigDict() + config.pretrained_backbone_configs.checkpoint_format = BACKBONE_ORIGIN + config.pretrained_backbone_configs.checkpoint_path = MODEL_PATH + config.pretrained_backbone_configs.token_init = True + config.pretrained_backbone_configs.classifier = 'token' + config.pretrained_backbone_configs.backbone_type = 'vit' + + # Logging. + config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. + config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. + config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. 
+ config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'standard' + config.eval_mode = False + config.eval_covariate_shift = True + config.eval_label_shift = True + config.model.input_shape = target_size + + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'msp' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. + + if runlocal: + config.count_flops = False + config.dataset_configs.train_target_size = (128, 128) + config.model.input_shape = config.dataset_configs.train_target_size + config.batch_size = 8 + config.num_training_epochs = 5 + config.warmup_steps = 0 + config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]' + config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]' + config.num_train_examples = TRAIN_SAMPLES + + return config + + +def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task): + """Defines checkpoints for sweep.""" + overwrites = [] + if resnet_size is not None: + raise NotImplementedError('') + else: + overwrites.append( + hyper.sweep('config.model.patches', [{ + 'size': (stride, stride) + }])) + + if vit_size == 'B': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768])) + elif vit_size == 'L': + 
overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024])) + else: + raise NotImplementedError('') + + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_format', + [backbone_origin])) + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [ + MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task)] + ])) + + return hyper.product(overwrites) + + +def get_sweep(hyper): + """Defines the hyper-parameters sweeps for doing grid search.""" + + parameters = [ + hyper.sweep('config.model.decoder.gp_layer.normalize_input', + [True, False]), + hyper.sweep('config.model.decoder.mean_field_factor', + hyper.discrete(range(1, 10))), + hyper.sweep('config.model.decoder.gp_layer.hidden_kwargs.feature_scale', + [1.0, 2.0]), + ] + + return hyper.product(parameters) \ No newline at end of file diff --git a/experimental/robust_segvit/configs/street_hazards/gp_eval.py b/experimental/robust_segvit/configs/street_hazards/gp_eval.py new file mode 100644 index 000000000..6ab27e8f5 --- /dev/null +++ b/experimental/robust_segvit/configs/street_hazards/gp_eval.py @@ -0,0 +1,260 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Evaluate segmenter_gp model on street_hazards. + + +""" +# pylint: enable=line-too-long + +import ml_collections +import os +import datetime + +_CITYSCAPES_FINE_TRAIN_SIZE = 2975 +_CITYSCAPES_COARSE_TRAIN_SIZE = 19998 + +_ADE20K_TRAIN_SIZE = 20210 +_PASCAL_VOC_TRAIN_SIZE = 10582 +_PASCAL_CONTEXT_TRAIN_SIZE = 4998 +_STREET_HAZARDS_TRAIN_SIZE = 5125 + +TRAIN_SIZES = { + 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE, + 'ade20k': _ADE20K_TRAIN_SIZE, + 'ade20k_ind': _ADE20K_TRAIN_SIZE, + 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE, + 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE, + 'street_hazards': _STREET_HAZARDS_TRAIN_SIZE + +} + +# Model specs. +target_size = (720, 720) + +LOAD_PRETRAINED_BACKBONE = True +VIT_SIZE = 'L' +STRIDE = 16 +RESNET_SIZE = None +CLASSIFIER = 'token' +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' +target_size = (720, 720) + +CHECKPOINT_ORIGIN = 'ub' +EXPERIMENTID='gp_run1' + +# Upstream +MODEL_PATHS = { + ('ub', 'L', 16, None, 'token', 'gp_run1'): + 'gs://ub-ekb/segmenter/street_hazards/gp/gp_2022-10-03-15-05-54', +} + + +MODEL_PATH = MODEL_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE, + RESNET_SIZE, CLASSIFIER, EXPERIMENTID)] + +if VIT_SIZE == 'B': + mlp_dim = 3072 + num_heads = 12 + num_layers = 12 + hidden_size = 768 +elif VIT_SIZE == 'L': + mlp_dim = 4096 + num_heads = 16 + num_layers = 24 + hidden_size = 1024 + +TRAIN_SAMPLES = 32 + + +def get_config(runlocal=''): + """Returns the configuration for street_hazards segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'street_hazards_gp_eval' + + # Dataset. 
+ config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = target_size + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + + config.dataset_configs.train_split = 'train' + config.dataset_configs.name = 'street_hazards' + config.dataset_configs.dataset_name = '' # ood name flag to write in eval. + + # Model. + config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.0 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = CLASSIFIER + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'gp' + + # GP layer params + config.model.decoder.gp_layer = ml_collections.ConfigDict() + config.model.decoder.gp_layer.covmat_kwargs = ml_collections.ConfigDict() + config.model.decoder.gp_layer.covmat_kwargs.ridge_penalty = 1. + # Disable momentum in order to use exact covariance update for finetuning. + # Disable to allow exact cov update. + config.model.decoder.gp_layer.covmat_kwargs.momentum = 0.99 + config.model.decoder.mean_field_factor = 1. + # Additional params + config.model.decoder.gp_layer.normalize_input = True + config.model.decoder.gp_layer.hidden_kwargs = ml_collections.ConfigDict() + config.model.decoder.gp_layer.hidden_kwargs.feature_scale = 1. + + # Training. 
+ config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(100) + config.batch_size = 32 + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. + config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name) + config.steps_per_epoch = config.get_ref( + 'num_train_examples') // config.get_ref('batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch') + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 1e-4 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # Load checkpoint + config.checkpoint_configs = ml_collections.ConfigDict() + config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN + config.checkpoint_configs.checkpoint_path = MODEL_PATH + config.checkpoint_configs.classifier = 'token' + + # Logging. + config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. + config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. + config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. 
+ config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'standard' + config.eval_mode = True + config.eval_covariate_shift = True + config.eval_label_shift = True + config.model.input_shape = target_size + config.eval_configs.store_logits = False + + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'mlogit' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. + + if runlocal: + config.count_flops = False + config.dataset_configs.train_target_size = (128, 128) + config.model.input_shape = config.dataset_configs.train_target_size + config.batch_size = 8 + config.num_training_epochs = 5 + config.warmup_steps = 0 + config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]' + config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]' + config.num_train_examples = TRAIN_SAMPLES + + return config + + +def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task): + """Defines checkpoints for sweep.""" + overwrites = [] + if resnet_size is not None: + raise NotImplementedError('') + else: + overwrites.append( + hyper.sweep('config.model.patches', [{ + 'size': (stride, stride) + }])) + + if vit_size == 'B': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768])) + elif vit_size == 
'L': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024])) + else: + raise NotImplementedError('') + + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_format', + [backbone_origin])) + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [ + MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task)] + ])) + + return hyper.product(overwrites) + + +def get_sweep(hyper): + return hyper.product([]) diff --git a/experimental/robust_segvit/configs/street_hazards/het.py b/experimental/robust_segvit/configs/street_hazards/het.py new file mode 100644 index 000000000..5493e2ad3 --- /dev/null +++ b/experimental/robust_segvit/configs/street_hazards/het.py @@ -0,0 +1,276 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Train segmenter model on street_hazards. + +Compare performance from deterministic upstream checkpoints. 
+ +""" +# pylint: enable=line-too-long + +import ml_collections +import os +import datetime + +_CITYSCAPES_FINE_TRAIN_SIZE = 2975 +_CITYSCAPES_COARSE_TRAIN_SIZE = 19998 + +_ADE20K_TRAIN_SIZE = 20210 +_PASCAL_VOC_TRAIN_SIZE = 10582 +_PASCAL_CONTEXT_TRAIN_SIZE = 4998 +_STREET_HAZARDS_TRAIN_SIZE = 5125 + +TRAIN_SIZES = { + 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE, + 'ade20k': _ADE20K_TRAIN_SIZE, + 'ade20k_ind': _ADE20K_TRAIN_SIZE, + 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE, + 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE, + 'street_hazards': _STREET_HAZARDS_TRAIN_SIZE, +} + +# Model specs. +LOAD_PRETRAINED_BACKBONE = True +BACKBONE_ORIGIN = 'vision_transformer' +VIT_SIZE = 'L' +STRIDE = 16 +RESNET_SIZE = None +CLASSIFIER = 'token' +target_size = (720, 720) +UPSTREAM_TASK = 'augreg+i21k+imagenet2012' + + +# Upstream +MODEL_PATHS = { + + # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 + ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'): + 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz', + ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'): + 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', +} + + +MODEL_PATH = MODEL_PATHS[(BACKBONE_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE, + CLASSIFIER, UPSTREAM_TASK)] + +if VIT_SIZE == 'B': + mlp_dim = 3072 + num_heads = 12 + num_layers = 12 + hidden_size = 768 +elif VIT_SIZE == 'L': + mlp_dim = 4096 + num_heads = 16 + num_layers = 24 + hidden_size = 1024 + +TRAIN_SAMPLES = 32 + + +def get_config(runlocal=''): + """Returns the configuration for street hazards segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'street_hazards_ind_segmenter_het_hyper' + + # Dataset. 
+ config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = target_size + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + + config.dataset_configs.train_split = 'train' + config.dataset_configs.name = 'street_hazards' + config.dataset_configs.dataset_name = '' # ood name flag to write in eval. + + # Model. + config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.1 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = CLASSIFIER + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'het' + + # Het layer params + # temp: wide sweep [0.15, 0.3, 0.5, 0.75, 1.0, 1.5, 2.0] + config.model.decoder.temperature = 1.0 + # efficient low rank approx ~ FxK where K is the classes. False for K<20. + config.model.decoder.param_efficient = False + # F as a low rank approx of KxK matrix has num_factors: + # imagenet~15, jft~50, cifar~6, cityscapes~sweep(5-10). + config.model.decoder.num_factors = 5 + # mc_samples: use as much as can be afforded, ideally > 10. + config.model.decoder.mc_samples = 1000 + config.model.decoder.return_locs = False + # turn on to run an approx on KHW x KHW instead of KxK. + config.model.decoder.share_samples_across_batch = False + + # Training. 
+ config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(100) + config.batch_size = 32 + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. + config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name) + config.steps_per_epoch = config.get_ref( + 'num_train_examples') // config.get_ref('batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. + config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch') + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 1e-4 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # load pretrained backbone + config.load_pretrained_backbone = LOAD_PRETRAINED_BACKBONE + config.pretrained_backbone_configs = ml_collections.ConfigDict() + config.pretrained_backbone_configs.checkpoint_format = BACKBONE_ORIGIN + config.pretrained_backbone_configs.checkpoint_path = MODEL_PATH + config.pretrained_backbone_configs.token_init = True + config.pretrained_backbone_configs.classifier = 'token' + config.pretrained_backbone_configs.backbone_type = 'vit' + + # Logging. + config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. + config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. 
+ config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. + config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'standard' + config.eval_mode = False + config.eval_covariate_shift = True + config.eval_label_shift = True + config.model.input_shape = target_size + + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'msp' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. + + if runlocal: + config.count_flops = False + config.dataset_configs.train_target_size = (128, 128) + config.model.input_shape = config.dataset_configs.train_target_size + config.batch_size = 8 + config.num_training_epochs = 5 + config.warmup_steps = 0 + config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]' + config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]' + config.num_train_examples = TRAIN_SAMPLES + + return config + + +def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task): + """Defines checkpoints for sweep.""" + overwrites = [] + if resnet_size is not None: + raise NotImplementedError('') + else: + overwrites.append( + hyper.sweep('config.model.patches', [{ + 'size': (stride, stride) + }])) + + if vit_size == 'B': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12])) + 
overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768])) + elif vit_size == 'L': + overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096])) + overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16])) + overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24])) + overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024])) + else: + raise NotImplementedError('') + + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_format', + [backbone_origin])) + overwrites.append( + hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [ + MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size, + classifier, upstream_task)] + ])) + + return hyper.product(overwrites) + + +def get_sweep(hyper): + """Defines the hyper-parameters sweeps for doing grid search.""" + parameters = [ + hyper.sweep('config.model.decoder.num_factors', + hyper.discrete([5, 10, 20, 50])), + hyper.sweep('config.model.decoder.temperature', + [0.15, 0.3, 0.5, 0.75, 1.0, 1.5, 2.0]), + hyper.sweep('config.model.decoder.share_samples_across_batch', + [True, False]), + hyper.sweep('config.model.decoder.param_efficient', + [True, False]), + ] + + return hyper.product(parameters) diff --git a/experimental/robust_segvit/configs/street_hazards/toy_model.py b/experimental/robust_segvit/configs/street_hazards/toy_model.py new file mode 100644 index 000000000..0d251d923 --- /dev/null +++ b/experimental/robust_segvit/configs/street_hazards/toy_model.py @@ -0,0 +1,170 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Train toy model on toy street_hazards dataset calling the robust_segvit codebase. + +""" +# pylint: enable=line-too-long + +import ml_collections +import os +import datetime + +_CITYSCAPES_FINE_TRAIN_SIZE = 2975 +_CITYSCAPES_COARSE_TRAIN_SIZE = 19998 + +_ADE20K_TRAIN_SIZE = 20210 +_PASCAL_VOC_TRAIN_SIZE = 10582 +_PASCAL_CONTEXT_TRAIN_SIZE = 4998 +_STREET_HAZARDS_TRAIN_SIZE = 5125 + +TRAIN_SIZES = { + 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE, + 'ade20k': _ADE20K_TRAIN_SIZE, + 'ade20k_ind': _ADE20K_TRAIN_SIZE, + 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE, + 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE, + 'street_hazards': _STREET_HAZARDS_TRAIN_SIZE +} + +# Model spec. +STRIDE = 4 +mlp_dim = 2 +num_heads = 1 +num_layers = 1 +hidden_size = 1 +target_size = (512, 512) + +TRAIN_SAMPLES = 32 + + +def get_config(runlocal=''): + """Returns the configuration for street_hazards segmentation.""" + + runlocal = bool(runlocal) + + config = ml_collections.ConfigDict() + config.experiment_name = 'street_hazards_segmenter_ind_toy_model' + + # Dataset. 
+ config.dataset_name = 'robust_segvit_segmentation' + config.dataset_configs = ml_collections.ConfigDict() + config.dataset_configs.target_size = target_size + config.dataset_configs.train_target_size = config.dataset_configs.get_ref( + 'target_size') + config.dataset_configs.denoise = None + config.dataset_configs.use_timestep = 0 + config.dataset_configs.train_split = 'train' + config.dataset_configs.validation_split = 'validation' + config.dataset_configs.name = 'street_hazards' + config.dataset_configs.dataset_name = '' # ood name flag to write in eval. + + # Model. + config.model_name = 'segvit' + config.model = ml_collections.ConfigDict() + + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (STRIDE, STRIDE) + + config.model.backbone = ml_collections.ConfigDict() + config.model.backbone.type = 'vit' + config.model.backbone.mlp_dim = mlp_dim + config.model.backbone.num_heads = num_heads + config.model.backbone.num_layers = num_layers + config.model.backbone.hidden_size = hidden_size + config.model.backbone.dropout_rate = 0.1 + config.model.backbone.attention_dropout_rate = 0.0 + config.model.backbone.classifier = 'token' + + # Decoder + config.model.decoder = ml_collections.ConfigDict() + config.model.decoder.type = 'linear' + + # Training. + config.trainer_name = 'segvit_trainer' + config.optimizer = 'adam' + config.optimizer_configs = ml_collections.ConfigDict() + config.l2_decay_factor = 0.0 + config.max_grad_norm = 1.0 + config.label_smoothing = None + config.num_training_epochs = ml_collections.FieldReference(2) + config.batch_size = 32 + config.rng_seed = 0 + config.focal_loss_gamma = 0.0 + + # Learning rate. + config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name) + config.steps_per_epoch = config.get_ref( + 'num_train_examples') // config.get_ref('batch_size') + # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine. 
+ config.lr_configs = ml_collections.ConfigDict() + config.lr_configs.learning_rate_schedule = 'compound' + config.lr_configs.factors = 'constant * cosine_decay * linear_warmup' + config.lr_configs.warmup_steps = 0 + config.lr_configs.steps_per_cycle = config.get_ref( + 'num_training_epochs') * config.get_ref('steps_per_epoch') + config.lr_configs.base_learning_rate = 5e-4 + + # model and data dtype + config.model_dtype_str = 'float32' + config.data_dtype_str = 'float32' + + # init not included + + # Logging. + config.write_summary = True + config.write_xm_measurements = True # write XM measurements + config.xprof = False # Profile using xprof. + config.checkpoint = True # Do checkpointing. + config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch') + + config.debug_train = False # Debug mode during training. + config.debug_eval = False # Debug mode during eval. + config.log_eval_steps = 1 * config.get_ref('steps_per_epoch') + + # Evaluation. + config.eval_mode = False + config.eval_configs = ml_collections.ConfigDict() + config.eval_configs.mode = 'standard' + config.eval_covariate_shift = True + config.eval_label_shift = True + + config.eval_robustness_configs = ml_collections.ConfigDict() + config.eval_robustness_configs.auc_online = True + config.eval_robustness_configs.method_name = 'mlogit' + + # wandb.ai configurations. + config.use_wandb = False + config.wandb_dir = 'wandb' + config.wandb_project = 'rdl-debug' + config.wandb_entity = 'ekellbuch' + config.wandb_exp_name = None # Give experiment a name. + config.wandb_exp_name = ( + os.path.splitext(os.path.basename(__file__))[0] + '_' + + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) + config.wandb_exp_group = None # Give experiment a group name. 
+ + if runlocal: + config.count_flops = False + config.batch_size = 8 + config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]' + config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]' + config.num_train_examples = TRAIN_SAMPLES + return config + + +def get_sweep(hyper): + return hyper.product([]) diff --git a/experimental/robust_segvit/custom_models.py b/experimental/robust_segvit/custom_models.py index f4b2b68a0..0cab33f38 100644 --- a/experimental/robust_segvit/custom_models.py +++ b/experimental/robust_segvit/custom_models.py @@ -391,6 +391,11 @@ def global_unc_metrics_fn( assert isinstance(all_unc_confusion_mats, list) # List of eval batches. cm = np.sum(all_unc_confusion_mats, axis=0) # Sum over eval batches. + if cm.ndim == 2: # [batch_size, 4] + pass + elif cm.ndim == 3: # [num_devices, batch_size per device, 4] + cm = np.sum(cm, axis=0) # sum over devices + assert cm.ndim == 2, ('Expecting uncertainty confusion matrix to have shape ' '[batch_size, 4], got ' f'{cm.shape}.') diff --git a/experimental/robust_segvit/custom_segmentation_trainer.py b/experimental/robust_segvit/custom_segmentation_trainer.py index 086ab1ae9..a7d19690e 100644 --- a/experimental/robust_segvit/custom_segmentation_trainer.py +++ b/experimental/robust_segvit/custom_segmentation_trainer.py @@ -44,12 +44,18 @@ from ensemble_utils import log_average_softmax_probs # local file import from experimental.robust_segvit from inference import process_batch # local file import from experimental.robust_segvit from ood_metrics import get_ood_metrics # local file import from experimental.robust_segvit -from ood_metrics import get_ood_score # local file import from experimental.robust_segvit from pretrainer_utils import convert_torch_to_jax_checkpoint # local file import from experimental.robust_segvit +from pretrainer_utils import convert_vision_transformer_to_scenic # local file import from experimental.robust_segvit from uncertainty_metrics import 
get_uncertainty_confusion_matrix # local file import from experimental.robust_segvit - +from checkpoint_utils import load_checkpoints_eval +from checkpoint_utils import load_checkpoints_backbone +import h5py +import os import resource import sys +import robustness_metrics as rm +from metrics_multihost import ComputeOODAUCMetric +from metrics_multihost import host_all_gather_metrics Batch = Dict[str, jnp.ndarray] MetricFn = Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], @@ -126,6 +132,7 @@ def evaluate(train_state: train_utils.TrainState, global_metrics_fn: Any, global_unc_metrics_fn: Optional[Any], prefix: str = 'valid', + workdir: str = '', ) -> Dict[str, Any]: """Model evaluator. @@ -160,17 +167,54 @@ def evaluate(train_state: train_utils.TrainState, # Evaluate global metrics on one of the hosts (lead_host), but given # intermediate values collected from all hosts. - for _ in range(steps_per_eval): + # setup calibration evaluation + ece_num_bins = config.get('ece_num_bins', 15) + ece_metric = rm.metrics.ExpectedCalibrationError(num_bins=ece_num_bins)._metric + calib_auc = rm.metrics.CalibrationAUC(correct_pred_as_pos_label=False)._metric + + # store logits + store_logits = config.eval_configs.get('store_logits', False) + + if store_logits: + store_logits_fname = os.path.join(workdir, "{}_{}_val.h5py".format(prefix,"logits")) + f = h5py.File(store_logits_fname, 'w', libver='latest') + f.swmr_mode = True # single write multi-read + input_shape = dataset.meta_data['input_shape'][1:3] + num_classes = dataset.meta_data['num_classes'] + num_eval_examples = int(steps_per_eval * config.batch_size) + logits_out = f.create_dataset('logits', (num_eval_examples,) + input_shape + (num_classes,)) + inputs_out = f.create_dataset('inputs', (num_eval_examples,) + input_shape + (3,)) + labels_out = f.create_dataset('labels', (num_eval_examples,) + input_shape) + + for step_ in range(steps_per_eval): eval_batch = next(dataset.valid_iter) e_batch, e_logits, e_metrics, 
confusion_matrix, unc_confusion_matrix = eval_step_pmapped( train_state=train_state, batch=eval_batch) eval_metrics.append(train_utils.unreplicate_and_get(e_metrics)) + + probs = jax.nn.softmax(e_logits, axis=-1) + # updates on each host separately + ece_metric.update_state(labels=e_batch['label'], probabilities=probs, sample_weight=e_batch['batch_mask']) + y_pred = jnp.argmax(probs, axis=-1) # predicted label indices + confidence = jnp.max(probs, axis=-1) # confidence score for predicted labels + calib_auc.update_state(y_true=e_batch['label'], y_pred=y_pred, confidence=confidence, sample_weight=e_batch['batch_mask']) + if lead_host and global_metrics_fn is not None: # Collect data to be sent for computing global metrics. eval_all_confusion_mats.append(to_cpu(confusion_matrix, all_gather=True)) eval_all_unc_confusion_mats.append( to_cpu(unc_confusion_matrix, all_gather=True)) + if store_logits: + start_idx = step_ * config.batch_size + end_idx = start_idx + config.batch_size + logits_out[start_idx:end_idx] = e_logits + inputs_out[start_idx:end_idx] = e_batch['inputs'] + labels_out[start_idx:end_idx] = e_batch['label'] + + if store_logits: + f.close() + # Compute global metrics eval_global_metrics_summary = {} if lead_host and global_metrics_fn is not None: @@ -190,6 +234,17 @@ def evaluate(train_state: train_utils.TrainState, prefix=prefix, ) + del e_metrics, eval_batch, eval_metrics, eval_global_metrics_summary + del eval_all_confusion_mats + del eval_all_unc_confusion_mats + + # Gather uncertainty metrics from all hosts and write value: + ece_metric = host_all_gather_metrics(ece_metric) + calib_auc = host_all_gather_metrics(calib_auc) + writer.write_scalars(step=step, scalars={'{}_ece'.format(prefix): ece_metric.result(), + '{}_calib_auc'.format(prefix): calib_auc.result(), + } ) + # Visualize val predictions for one batch: if lead_host: # in eval_step we do not use all_gather in batch or logits @@ -206,10 +261,7 @@ def evaluate(train_state: 
train_utils.TrainState, writer.flush() # Free some memory - del eval_metrics - del eval_global_metrics_summary - del eval_all_confusion_mats - del eval_all_unc_confusion_mats + del example_viz, images, e_batch, e_predictions, e_logits, logits return eval_summary @@ -222,6 +274,7 @@ def evaluate_ood( writer: metric_writers.MetricWriter, lead_host: Any, prefix: str = 'valid', + workdir: str ='', **kwargs, ) -> Dict[str, Any]: """Model evaluator. @@ -250,89 +303,55 @@ def evaluate_ood( auc_online = kwargs.pop('auc_online', False) + # store logits + store_logits = config.eval_configs.get('store_logits', False) + + if store_logits: + store_logits_fname = os.path.join(workdir, "{}_{}_val.h5py".format(prefix,"logits")) + f = h5py.File(store_logits_fname, 'w', libver='latest') + f.swmr_mode = True # single write multi-read + input_shape = dataset.meta_data['input_shape'][1:3] + num_classes = dataset.meta_data['num_classes'] + num_eval_examples = int(steps_per_eval * config.batch_size) + logits_out = f.create_dataset('logits', (num_eval_examples,) + input_shape + (num_classes,)) + inputs_out = f.create_dataset('inputs', (num_eval_examples,) + input_shape + (3,)) + labels_out = f.create_dataset('labels', (num_eval_examples,) + input_shape) + if auc_online: # TODO(kellybuchanan): check split of data across devices. # initialize metrics: ideally in each device in each host/process/machine - # keras initializes one metric in each host because it runs in cpu - # so we need to convert to jax to run metrics in each device in each host - - auc_pr = tf.keras.metrics.AUC(curve='PR') - auc_roc = tf.keras.metrics.AUC(curve='ROC') + # keras initializes one metric in each host because it runs in cpu. + # so we need to convert the function to run metrics in each device/host. 
+ auc_pr = ComputeOODAUCMetric(curve='PR', num_thresholds=100) + auc_roc = ComputeOODAUCMetric(curve='ROC', num_thresholds=100) # Loop through each machine: - for _ in range(steps_per_eval): + for step_ in range(steps_per_eval): eval_batch = next(dataset.valid_iter) e_batch, e_logits = eval_step_pmapped( train_state=train_state, batch=eval_batch) - # In eval_step_pmapped we have not used all gather, so each metric is in - # each device and we should be able to compute devices separately - - ood_score = get_ood_score(e_logits, **kwargs) - - auc_roc.update_state( - e_batch['label'], ood_score, sample_weight=e_batch['batch_mask']) - auc_pr.update_state( - e_batch['label'], ood_score, sample_weight=e_batch['batch_mask']) - - # How to communicate metrics across hosts? - # Ideally we can collect auc_metrics per host, merge them, compute result. - # However, we cannot pass arbitraty class. - # jax which doesn't work with arbitrary objects - # Here we write a custom merge_state as in tf.keras.metrics - # by pulling states from tf.keras obj, combining them and putting them back - # into a keras object using list of host's auc_roc objects. - - def keras_auc_to_arrays(keras_auc_object): - """Pull out arrays from keras roc object.""" - # The thresholds used are determinisitc, so we need not store them. 
- tp = jnp.asarray(keras_auc_object.true_positives) - fp = jnp.asarray(keras_auc_object.false_positives) - tn = jnp.asarray(keras_auc_object.true_negatives) - fn = jnp.asarray(keras_auc_object.false_negatives) - return tp, fp, tn, fn - - def arrays_to_keras_auc(tp, fp, tn, fn, keras_auc_object): - """Assign confusion matrix arrays to a keras_auc_object.""" - keras_auc_object.true_positives.assign(tp) - keras_auc_object.false_positives.assign(fp) - keras_auc_object.true_negatives.assign(tn) - keras_auc_object.false_negatives.assign(fn) - return keras_auc_object - - auc_roc_state = keras_auc_to_arrays(auc_roc) - auc_pr_state = keras_auc_to_arrays(auc_pr) - - def combine_states(all_auc_states): - # jax can take in trees of arrays, tuple is considered a tree so we can - # unpack it here. - # each array here has dimensions #host x shape - - all_tp, all_fp, all_tn, all_fn = all_auc_states - - assert all_tp.shape == (jax.process_count(), 200) - assert all_fp.shape == (jax.process_count(), 200) - assert all_tn.shape == (jax.process_count(), 200) - assert all_fn.shape == (jax.process_count(), 200) - - tp = jnp.sum(all_tp, 0) - fp = jnp.sum(all_fp, 0) - tn = jnp.sum(all_tn, 0) - fn = jnp.sum(all_fn, 0) - - return tp, fp, tn, fn - - # Gather the data across all hosts. - all_auc_roc_states = multihost_utils.process_allgather(auc_roc_state) - all_auc_pr_states = multihost_utils.process_allgather(auc_pr_state) - - # Below we pick the first device. 
- auc_roc = arrays_to_keras_auc(*combine_states(all_auc_roc_states), auc_roc) - auc_pr = arrays_to_keras_auc(*combine_states(all_auc_pr_states), auc_pr) - - eval_summary = {'auroc': float(auc_roc.result().numpy()), - 'auprc': float(auc_pr.result().numpy()), + if store_logits: + start_idx = step_ * config.batch_size + end_idx = start_idx + config.batch_size + logits_out[start_idx:end_idx] = e_logits + inputs_out[start_idx:end_idx] = e_batch['inputs'] + labels_out[start_idx:end_idx] = e_batch['labels'] + + # In eval_step_pmapped we have not used all gather, so each metric is in each device + # and we should be able to compute metrics in devices separately. + auc_pr.calculate_and_update_scores(logits=e_logits, label=e_batch['label'], + sample_weight=e_batch['batch_mask'], **kwargs) + auc_roc.calculate_and_update_scores(logits=e_logits, label=e_batch['label'], + sample_weight=e_batch['batch_mask'], **kwargs) + + if store_logits: + f.close() + + eval_summary = {'auroc': auc_roc.gather_metrics(), + 'auprc': auc_pr.gather_metrics(), } + else: eval_logits = [] eval_ood_masks = [] @@ -344,7 +363,7 @@ def combine_states(all_auc_states): e_batch, e_logits = eval_step_pmapped( train_state=train_state, batch=eval_batch) - # Store all logits in cpu + # Store all logits in cpu: if lead_host: e_batch = to_cpu(e_batch, all_gather=False) e_logits = to_cpu(e_logits, all_gather=False) @@ -363,6 +382,7 @@ def combine_states(all_auc_states): ood_mask=eval_ood_labels, weights=eval_ood_masks, **kwargs) + ############### LOG EVAL SUMMARY ############### writer.write_scalars( step, { @@ -498,12 +518,14 @@ def training_loss_fn(params): logits) # batch_size x h x w x num_classes metrics = metrics_fn(logits, batch) - new_train_state = train_state.replace( # pytype: disable=attribute-error + logits = jnp.argmax(logits, axis=-1) + + train_state = train_state.replace( # pytype: disable=attribute-error global_step=step + 1, optimizer=new_optimizer, model_state=new_model_state, rng=new_rng) - 
return new_train_state, metrics, lr, jnp.argmax(logits, axis=-1) + return train_state, metrics, lr, logits def eval_step( @@ -588,10 +610,12 @@ def eval_step( # Collect predictions and batches from all hosts. # use all_gather to copy and replicate across all hosts # we skip doing this for batch and logits to save memory + # unless we want to store the logits # predictions = jnp.argmax(logits, axis=-1) # predictions = jax.lax.all_gather(predictions, 'batch') - # logits = jax.lax.all_gather(logits, 'batch') - # batch = jax.lax.all_gather(batch, 'batch') + if config.eval_configs.get('store_logits', False): + logits = jax.lax.all_gather(logits, 'batch') + batch = jax.lax.all_gather(batch, 'batch') confusion_matrix = jax.lax.all_gather(confusion_matrix, 'batch') unc_confusion_matrix = jax.lax.all_gather(unc_confusion_matrix, 'batch') @@ -656,9 +680,10 @@ def eval_step_baseline( # Collect predictions and batches from all hosts. # use all_gather to copy and replicate across all hosts # we can skip doing this for batch and logits to save memory - # is the OOM in tpu or cpu? - # batch = jax.lax.all_gather(batch, 'batch') - # logits = jax.lax.all_gather(logits, 'batch') + # jis the OOM in tpu or cpu? 
+ if config.eval_configs.get('store_logits', False): + batch = jax.lax.all_gather(batch, 'batch') + logits = jax.lax.all_gather(logits, 'batch') return batch, logits @@ -741,28 +766,7 @@ def train( # Load pretrained backbone if start_step == 0 and config.get('load_pretrained_backbone', False): - # TODO(kellybuchanan): check out partial loader in - # https://github.com/google/uncertainty-baselines/commit/083b1dcc52bb1964f8917d15552ece8848d582ae# - restored_model_cfg = config.get('pretrained_backbone_configs') - - # Loader from scenic - if restored_model_cfg.checkpoint_format in ('ub', 'big_vision', 'scenic'): - # load params from checkpoint - bb_train_state = pretrain_utils.convert_big_vision_to_scenic_checkpoint( - checkpoint_path=restored_model_cfg.checkpoint_path, - convert_to_linen=False) - - train_state = model.init_backbone_from_train_state( - train_state, - bb_train_state, - config, - restored_model_cfg, - model_prefix_path=['backbone']) - # Free unnecessary memory. - del bb_train_state - else: - raise NotImplementedError('') - + train_state = load_checkpoints_backbone(config, model, train_state, workdir) elif start_step == 0: logging.info('Not restoring from any pretrained_backbone.') @@ -816,7 +820,7 @@ def train( checkpoint_steps = config.get('checkpoint_steps') or log_eval_steps train_metrics, extra_training_logs = [], [] - train_summary, eval_summary = None, None + train_summary, eval_summary = {}, {} global_metrics_fn = model.get_global_metrics_fn() # pytype: disable=attribute-error global_unc_metrics_fn = model.get_global_unc_metrics_fn() # pytype: disable=attribute-error @@ -883,12 +887,13 @@ def train( train_summary = train_utils.log_train_summary( step=step, - train_metrics=jax.tree_map(train_utils.unreplicate_and_get, + train_metrics=jax.tree_util.tree_map(train_utils.unreplicate_and_get, train_metrics), - extra_training_logs=jax.tree_map(train_utils.unreplicate_and_get, + extra_training_logs=jax.tree_util.tree_map(train_utils.unreplicate_and_get, 
extra_training_logs), writer=writer) + del example_viz, train_metrics, extra_training_logs # Reset metric accumulation for next evaluation cycle. train_metrics, extra_training_logs = [], [] @@ -908,6 +913,7 @@ def train( lead_host=lead_host, global_metrics_fn=global_metrics_fn, global_unc_metrics_fn=global_unc_metrics_fn, + workdir=workdir, ) # check accuracy for early stopping. @@ -1047,26 +1053,7 @@ def eval_ckpt( checkpoint_configs = config.get('checkpoint_configs', False) if checkpoint_configs: - # Load torch weights - if 'torch' in checkpoint_configs.checkpoint_format: - - bb_train_state = convert_torch_to_jax_checkpoint( - checkpoint_path=checkpoint_configs.checkpoint_path, - config=checkpoint_configs) - - train_state = model.init_backbone_from_train_state( - train_state, - bb_train_state, - config, - checkpoint_configs - ) - del bb_train_state - - # Load weights in checkpoint_path or workdir - else: - checkpoint_path = checkpoint_configs.get('checkpoint_path', workdir) - train_state, _ = train_utils.restore_checkpoint( - checkpoint_path, train_state) + train_state = load_checkpoints_eval(config, model, train_state, workdir) else: logging.info('Not loading any checkpoints') @@ -1103,6 +1090,7 @@ def eval_ckpt( global_metrics_fn=global_metrics_fn, global_unc_metrics_fn=global_unc_metrics_fn, prefix=prefix, + workdir=workdir, ) # Wait until computations are done before running robustness evaluator. @@ -1111,6 +1099,7 @@ def eval_ckpt( # ---------------------------------------------------------------------------- # Evaluate OOD datasets + logging.info('Evaluating OOD datasets') eval_summary_ood = evaluate_ood_step( train_state=train_state, config=config, @@ -1153,7 +1142,7 @@ def evaluate_ood_step( Returns: eval_summary: summary evaluation """ - del workdir + eval_summary = {} if config.get('eval_covariate_shift', False): @@ -1168,12 +1157,13 @@ def evaluate_ood_step( # We can donate the eval_batch's buffer. 
) - eval_summary = None global_metrics_fn = model.get_global_metrics_fn() # pytype: disable=attribute-error global_unc_metrics_fn = model.get_global_unc_metrics_fn() # pytype: disable=attribute-error eval_ood_covariate = {'cityscapes_c': evaluate_cityscapes_c, - 'ade20k_ind_c': evaluate_ade20k_corrupted,} + 'ade20k_ind_c': evaluate_ade20k_corrupted, + 'street_hazards_c': evaluate_street_hazards_corrupted, + } # TODO(kellybuchanan): merge data sources. # The form of the ind dataset name depends on the source of the data. @@ -1189,22 +1179,28 @@ def evaluate_ood_step( elif any('ade20k' in ind_name for ind_name in ind_names): logging.info('Loading Ade20k_ind_c') ood_dataset = 'ade20k_ind_c' + elif any('street' in ind_name for ind_name in ind_names): + logging.info('Loading street_hazards_c') + ood_dataset = 'street_hazards_c' else: logging.info('OOD Covariate shift dataset is not implemented') + ood_dataset = None - eval_summary = eval_ood_covariate[ood_dataset]( - train_state=train_state, - config=config, - rng=rng, - eval_step_pmapped=eval_step_pmapped, - writer=writer, - lead_host=lead_host, - global_metrics_fn=global_metrics_fn, - global_unc_metrics_fn=global_unc_metrics_fn, - ) + if ood_dataset: + eval_summary = eval_ood_covariate[ood_dataset]( + train_state=train_state, + config=config, + rng=rng, + eval_step_pmapped=eval_step_pmapped, + writer=writer, + lead_host=lead_host, + global_metrics_fn=global_metrics_fn, + global_unc_metrics_fn=global_unc_metrics_fn, + workdir=workdir, + ) - # Wait until computations are done before exiting. - jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready() + # Wait until computations are done before exiting. 
+ jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready() # ---------------------------------------------------------------------------- if config.get('eval_label_shift', False): @@ -1221,7 +1217,8 @@ def evaluate_ood_step( eval_label_shift = { 'fishyscapes': evaluate_fishyscapes, - 'ade20k_ood_open': evaluate_ade20k_ood_open + 'ade20k_ood_open': evaluate_ade20k_ood_open, + 'street_hazards_ood_open': evaluate_street_hazards_ood_open, } # The form of the ind dataset name depends on the source of the data. @@ -1234,25 +1231,29 @@ def evaluate_ood_step( if any('cityscapes' in ind_name for ind_name in ind_names): logging.info('Loading Fishyscapes...') ood_dataset = 'fishyscapes' - - if any('ade20k' in ind_name for ind_name in ind_names): + elif any('ade20k' in ind_name for ind_name in ind_names): logging.info('Loading ADE20k OOD OPEN...') ood_dataset = 'ade20k_ood_open' - + elif any('street_hazards' in ind_name for ind_name in ind_names): + logging.info('Loading StreetHazards OPEN...') + ood_dataset = 'street_hazards_ood_open' else: logging.info('OOD Label shift dataset is not implemented') + ood_dataset = None - eval_summary = eval_label_shift[ood_dataset]( - train_state=train_state, - config=config, - rng=rng, - eval_step_pmapped=eval_step_ood_pmapped, - writer=writer, - lead_host=lead_host, - ) + if ood_dataset: + eval_summary = eval_label_shift[ood_dataset]( + train_state=train_state, + config=config, + rng=rng, + eval_step_pmapped=eval_step_ood_pmapped, + writer=writer, + lead_host=lead_host, + workdir=workdir, + ) - # Wait until computations are done before exiting. - jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready() + # Wait until computations are done before exiting. 
+ jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready() return eval_summary @@ -1265,6 +1266,7 @@ def evaluate_cityscapes_c( lead_host: Any, global_metrics_fn: Any, global_unc_metrics_fn: Any, + workdir: str = None, ) -> Dict[str, Any]: """Evaluate cityscapes-c dataset. @@ -1291,21 +1293,21 @@ def evaluate_cityscapes_c( # update config: ood_config = ml_collections.ConfigDict() ood_config.update(**config) - ood_config.update({'dataset_name': 'cityscapes_variants'}) + ood_config.update({'dataset_name': 'robust_segvit_variants'}) accuracy_per_corruption = {} - prefix = 'citycvalid' - for corruption in cityscapes_variants.CITYSCAPES_C_CORRUPTIONS: + prefix = 'cityc' + for corruption in datasets_info.CITYSCAPES_C_CORRUPTIONS: local_list = [] # list to compute macro average per corruption - for severity in cityscapes_variants.CITYSCAPES_C_SEVERITIES: + for severity in datasets_info.CITYSCAPES_C_SEVERITIES: with ood_config.unlocked(): - ood_config.dataset_configs.dataset_name = f'cityscapes_corrupted/semantic_segmentation_{corruption}_{severity}' + ood_config.dataset_configs.name = f'cityscapes_c_{corruption}_{severity}' rng, data_rng = jax.random.split(rng) dataset = train_utils.get_dataset(ood_config, data_rng) dataset.meta_data['dataset_name'] = 'cityscapes_c' - dataset.meta_data['prefix'] = prefix + f'_{corruption}_{severity}' + dataset.meta_data['prefix'] = prefix + f'/{corruption}/{severity}/valid' eval_summary = evaluate( train_state=train_state, @@ -1318,6 +1320,7 @@ def evaluate_cityscapes_c( global_metrics_fn=global_metrics_fn, global_unc_metrics_fn=global_unc_metrics_fn, prefix=dataset.meta_data['prefix'], + workdir=workdir, ) local_list.append(eval_summary) @@ -1331,7 +1334,7 @@ def evaluate_cityscapes_c( # append name to metrics key_separator = '_' avg_cityscapes_c_metrics = { - key_separator.join((prefix, key)): val + key_separator.join((prefix + '/valid', key)): val for key, val in cityscapes_c_metrics.items() } # update metrics @@ -1348,6 
+1351,7 @@ def evaluate_fishyscapes( eval_step_pmapped: Any, writer: metric_writers.MetricWriter, lead_host: Any, + workdir: str = '', ) -> Dict[str, Any]: """Evaluate Fishyscapes dataset. @@ -1372,21 +1376,21 @@ def evaluate_fishyscapes( # update config: ood_config = ml_collections.ConfigDict() ood_config.update(**config) - ood_config.update({'dataset_name': 'cityscapes_variants'}) + ood_config.update({'dataset_name': 'robust_segvit_variants'}) device_count = jax.device_count() accuracy_per_corruption = {} - prefix = 'fishyvalid' - for corruption in cityscapes_variants.FISHYSCAPES_CORRUPTIONS: + prefix = 'fishyscapes' + for corruption in datasets_info.FISHYSCAPES_CORRUPTIONS: with ood_config.unlocked(): - ood_config.dataset_configs.dataset_name = f'fishyscapes/{corruption}' + ood_config.dataset_configs.name = f'fishyscapes/{corruption}' ood_config.batch_size = device_count data_rng, rng = jax.random.split(rng) dataset = train_utils.get_dataset(ood_config, data_rng) dataset.meta_data['dataset_name'] = 'fishyscapes' - dataset.meta_data['prefix'] = prefix + f'_{corruption}' + dataset.meta_data['prefix'] = prefix + f'/{corruption}/valid' eval_summary = evaluate_ood( train_state=train_state, @@ -1397,6 +1401,7 @@ def evaluate_fishyscapes( writer=writer, lead_host=lead_host, prefix=dataset.meta_data['prefix'], + workdir=workdir, **config.get('eval_robustness_configs', {}), ) @@ -1408,7 +1413,7 @@ def evaluate_fishyscapes( # append name to metrics key_separator = '_' avg_fishyscapes_metrics = { - key_separator.join((prefix, key)): val + key_separator.join((prefix +'/valid', key)): val for key, val in fishyscapes_metrics.items() } # update metrics @@ -1425,6 +1430,7 @@ def evaluate_ade20k_ood_open( eval_step_pmapped: Any, writer: metric_writers.MetricWriter, lead_host: Any, + workdir: str = '', ) -> Dict[str, Any]: """Evaluate ADE20k OOD dataset. 
@@ -1452,7 +1458,7 @@ def evaluate_ade20k_ood_open( ood_config.update({'dataset_name': 'robust_segvit_segmentation'}) device_count = jax.device_count() - prefix = 'ade20k_ood_open' + prefix = 'ade20k_ood_open/valid' with ood_config.unlocked(): ood_config.dataset_configs.name = 'ade20k_ood_open' @@ -1471,10 +1477,11 @@ def evaluate_ade20k_ood_open( writer=writer, lead_host=lead_host, prefix=dataset.meta_data['prefix'], + workdir=workdir, **config.get('eval_robustness_configs', {}), ) - # append name to metrics + # append name to metrics: key_separator = '_' avg_open_set_metrics = { key_separator.join((prefix, key)): val @@ -1497,6 +1504,7 @@ def evaluate_ade20k_corrupted( lead_host: Any, global_metrics_fn: Any, global_unc_metrics_fn: Any, + workdir : str, ) -> Dict[str, Any]: """Evaluate Ade20k-C dataset. @@ -1530,14 +1538,14 @@ def evaluate_ade20k_corrupted( prefix = 'ade20k_ind_c' for corruption in datasets_info.ADE20K_C_CORRUPTIONS: local_list = [] # list to compute macro average per corruption - for severity in range(1, 6): + for severity in datasets_info.ADE20K_C_SEVERITIES: with ood_config.unlocked(): ood_config.dataset_configs.name = f'ade20k_ind_c_{corruption}_{severity}' data_rng, rng = jax.random.split(rng) dataset = train_utils.get_dataset(ood_config, data_rng) - dataset.meta_data['prefix'] = prefix + f'_{corruption}_{severity}' + dataset.meta_data['prefix'] = prefix + f'/{corruption}/{severity}/valid' eval_summary = evaluate( train_state=train_state, @@ -1550,6 +1558,7 @@ def evaluate_ade20k_corrupted( global_metrics_fn=global_metrics_fn, global_unc_metrics_fn=global_unc_metrics_fn, prefix=dataset.meta_data['prefix'], + workdir=workdir, ) local_list.append(eval_summary) @@ -1563,7 +1572,94 @@ def evaluate_ade20k_corrupted( # append name to metrics key_separator = '_' avg_corrupted_metrics = { - key_separator.join((prefix, key)): val + key_separator.join((prefix + '/valid', key)): val + for key, val in ade20k_c_metrics.items() + } + # update metrics + 
eval_summary.update(avg_corrupted_metrics) + writer.write_scalars(0, avg_corrupted_metrics) + writer.flush() + return eval_summary + + +def evaluate_street_hazards_corrupted( + train_state: train_utils.TrainState, + config: ml_collections.ConfigDict, + rng: Any, + eval_step_pmapped: Any, + writer: metric_writers.MetricWriter, + lead_host: Any, + global_metrics_fn: Any, + global_unc_metrics_fn: Any, + workdir : str, +) -> Dict[str, Any]: + """Evaluate StreetHazards-C dataset. + + Args: + train_state: train state. + config: experiment configuration. + rng: jax rng. + eval_step_pmapped: eval state + writer: CLU metrics writer instance. + lead_host: Evaluate global metrics on one of the hosts (lead_host) given + intermediate values collected from all hosts. + global_metrics_fn: global metrics to evaluate. + global_unc_metrics_fn: global uncertainty metrics to evaluate. + Returns: + eval_summary: summary evaluation + """ + # Load dataset + # set resource limit to debug in mac osx + # (see https://github.com/tensorflow/datasets/issues/1441) + if jax.process_index() == 0 and sys.platform == 'darwin': + low, high = resource.getrlimit(resource.RLIMIT_NOFILE) + resource.setrlimit(resource.RLIMIT_NOFILE, (low, high)) + + # update config: + ood_config = ml_collections.ConfigDict() + ood_config.update(**config) + ood_config.update({'dataset_name': 'robust_segvit_variants'}) + + # Calculate metrics per corruption. 
+ accuracy_per_corruption = {} + prefix = 'street_hazards_c' + for corruption in datasets_info.STREETHAZARDS_C_CORRUPTIONS: + local_list = [] # list to compute macro average per corruption + for severity in datasets_info.STREETHAZARDS_C_SEVERITIES: + + with ood_config.unlocked(): + ood_config.dataset_configs.name = f'street_hazards_c_{corruption}_{severity}' + + data_rng, rng = jax.random.split(rng) + dataset = train_utils.get_dataset(ood_config, data_rng) + dataset.meta_data['prefix'] = prefix + f'/{corruption}/{severity}/valid' + + eval_summary = evaluate( + train_state=train_state, + dataset=dataset, + config=ood_config, + step=0, + eval_step_pmapped=eval_step_pmapped, + writer=writer, + lead_host=lead_host, + global_metrics_fn=global_metrics_fn, + global_unc_metrics_fn=global_unc_metrics_fn, + prefix=dataset.meta_data['prefix'], + workdir=workdir, + ) + + local_list.append(eval_summary) + + accuracy_per_corruption[corruption] = eval_utils.average_list_of_dicts( + local_list) + + ade20k_c_metrics = eval_utils.average_list_of_dicts( + accuracy_per_corruption.values()) + + # append name to metrics + key_separator = '_' + avg_corrupted_metrics = { + key_separator.join((prefix + '/valid', key)): val for key, val in ade20k_c_metrics.items() } # update metrics @@ -1571,3 +1667,76 @@ def evaluate_ade20k_corrupted( writer.write_scalars(0, avg_corrupted_metrics) writer.flush() return eval_summary + + +def evaluate_street_hazards_ood_open( + train_state: train_utils.TrainState, + config: ml_collections.ConfigDict, + rng: Any, + eval_step_pmapped: Any, + writer: metric_writers.MetricWriter, + lead_host: Any, + workdir: str, +) -> Dict[str, Any]: + """Evaluate StreetHazards OOD dataset. + + Args: + train_state: train state. + config: experiment configuration. + rng: jax rng. + eval_step_pmapped: eval state + writer: CLU metrics writer instance. + lead_host: Evaluate global metrics on one of the hosts (lead_host) given + intermediate values collected from all hosts. 
+ + Returns: + eval_summary: summary evaluation + """ + # set resource limit to debug in mac osx + # (see https://github.com/tensorflow/datasets/issues/1441) + if jax.process_index() == 0 and sys.platform == 'darwin': + low, high = resource.getrlimit(resource.RLIMIT_NOFILE) + resource.setrlimit(resource.RLIMIT_NOFILE, (low, high)) + + # update config: + ood_config = ml_collections.ConfigDict() + ood_config.update(**config) + ood_config.update({'dataset_name': 'robust_segvit_segmentation'}) + + device_count = jax.device_count() + prefix = 'street_hazards_open/valid' + + with ood_config.unlocked(): + ood_config.dataset_configs.name = 'street_hazards_open' + ood_config.batch_size = device_count + + data_rng, rng = jax.random.split(rng) + dataset = train_utils.get_dataset(ood_config, data_rng) + dataset.meta_data['prefix'] = prefix + + eval_summary = evaluate_ood( + train_state=train_state, + dataset=dataset, + config=ood_config, + step=0, + eval_step_pmapped=eval_step_pmapped, + writer=writer, + lead_host=lead_host, + prefix=dataset.meta_data['prefix'], + workdir=workdir, + **config.get('eval_robustness_configs', {}), + ) + + # append name to metrics + key_separator = '_' + avg_open_set_metrics = { + key_separator.join((prefix, key)): val + for key, val in eval_summary.items() + } + # update metrics + eval_summary.update(avg_open_set_metrics) + writer.write_scalars(0, avg_open_set_metrics) + writer.flush() + + return eval_summary + diff --git a/experimental/robust_segvit/custom_segmentation_trainer_test.py b/experimental/robust_segvit/custom_segmentation_trainer_test.py index 90b6c7553..5a6d06d42 100644 --- a/experimental/robust_segvit/custom_segmentation_trainer_test.py +++ b/experimental/robust_segvit/custom_segmentation_trainer_test.py @@ -34,7 +34,7 @@ from sklearn import metrics as sk_metrics import tensorflow as tf import custom_segmentation_trainer # local file import from experimental.robust_segvit - +import custom_models class 
SegmentationTrainerTest(parameterized.TestCase): """Tests the default trainer on single device setup.""" @@ -226,6 +226,61 @@ def test_get_confusion_matrix(self, seed, masked_fraction): self.assertAlmostEqual(metrics_dict['mean_iou'], miou_np, places=4) + @parameterized.parameters([(0, 0.0), (1, 0.01), (2, 0.5), (3, 0.99), (4, 1)]) + def test_unc_confusion_matrix(self, seed, masked_fraction): + """Test computation of mIoU metric.""" + np.random.seed(seed) + + # Create test data: + num_classes = 3 + input_shape = [8, 1, 224, 224] + logits_shape = input_shape + [num_classes] + logits_np = np.random.rand(*logits_shape) + logits = jnp.array(logits_np) + + # when the uncertainty threshold is 100% or = 0 + # all labels are certain, and pavpu is the fraction of patches that are accurate. + uncertainty_th = 0.0 + window_size = 1 + + # Note: We include label -1, which indicates excluded pixels: + label = np.random.randint(0, num_classes, size=input_shape) + label[:4] = np.argmax(logits_np[:4], axis=-1) # Set half to correct. 
+ + batch_np = { + 'label': + label, + 'batch_mask': + (np.random.rand(*input_shape) > masked_fraction) & (label != -1), + } + batch = { + 'label': jnp.array(batch_np['label']), + 'batch_mask': jnp.array(batch_np['batch_mask']), + } + + cm_pmapped = jax.pmap( + functools.partial( + custom_segmentation_trainer.get_uncertainty_confusion_matrix, + uncertainty_th=uncertainty_th, + window_size=window_size, + uncertainty_measure='softmax', + ), axis_name='batch') + + unc_confusion_matrix = [ + cm_pmapped(labels=labels, logits=logits_, weights=masks) + for labels, logits_, masks in + zip(batch['label'], logits, batch['batch_mask'])] + unc_confusion_matrix = jax.device_get(jax_utils.unreplicate(unc_confusion_matrix)) + metrics_dict = custom_models.global_unc_metrics_fn( + unc_confusion_matrix) + labels_negative_ignored = np.maximum(batch_np['label'], 0) + y_pred = np.argmax(logits_np, axis=-1) + weights = batch_np['batch_mask'] + accurate = labels_negative_ignored == y_pred + pavpu = np.sum(accurate * weights) / np.sum(weights) + # all labels are certain, pavpu = fraction of patches that are accurate + self.assertAlmostEqual(metrics_dict['pacc_cert'], jnp.nan_to_num(pavpu), places=2) + if __name__ == '__main__': absltest.main() diff --git a/experimental/robust_segvit/metrics_multihost.py b/experimental/robust_segvit/metrics_multihost.py new file mode 100644 index 000000000..ae082fb60 --- /dev/null +++ b/experimental/robust_segvit/metrics_multihost.py @@ -0,0 +1,133 @@ +"""Calculate ood metrics across hosts. + +# How to communicate metrics across hosts? +# Ideally we can collect auc_metrics per host, merge them, compute result. +# However, we cannot pass arbitraty class. +# jax which doesn't work with arbitrary objects. 
+
+
+"""
+from typing import Any, Optional, Dict
+
+import jax
+import jax.numpy as jnp
+import tensorflow as tf
+from jax.experimental import multihost_utils
+from ood_metrics import get_ood_score
+from ood_metrics import get_score
+import numpy as np
+import copy
+
+
+# Here we write a custom merge_state as in tf.keras.metrics
+# by pulling states from tf.keras obj, combining them and putting them back
+# into a keras object using list of host's auc_roc objects.
+def keras_auc_to_arrays(keras_auc_object):
+  """Pulls the confusion-matrix counters out of a keras AUC object.
+
+  Args:
+    keras_auc_object: Object exposing `true_positives`, `false_positives`,
+      `true_negatives` and `false_negatives`, e.g. a `tf.keras.metrics.AUC`.
+
+  Returns:
+    The tuple `(tp, fp, tn, fn)` of those counters as jax arrays.
+  """
+  # The thresholds used are deterministic, so we need not store them.
+  tp = jnp.asarray(keras_auc_object.true_positives)
+  fp = jnp.asarray(keras_auc_object.false_positives)
+  tn = jnp.asarray(keras_auc_object.true_negatives)
+  fn = jnp.asarray(keras_auc_object.false_negatives)
+  return tp, fp, tn, fn
+
+
+def arrays_to_keras_auc(tp, fp, tn, fn, keras_auc_object):
+  """Assigns confusion-matrix arrays to `keras_auc_object`, in place.
+
+  Args:
+    tp: New value for `keras_auc_object.true_positives`.
+    fp: New value for `keras_auc_object.false_positives`.
+    tn: New value for `keras_auc_object.true_negatives`.
+    fn: New value for `keras_auc_object.false_negatives`.
+    keras_auc_object: Object whose four counters support `.assign(...)`.
+
+  Returns:
+    The same `keras_auc_object`, after its counters were overwritten.
+  """
+  keras_auc_object.true_positives.assign(tp)
+  keras_auc_object.false_positives.assign(fp)
+  keras_auc_object.true_negatives.assign(tn)
+  keras_auc_object.false_negatives.assign(fn)
+  return keras_auc_object
+
+
+def combine_states(all_auc_states, num_thresholds=200):
+  """Sums per-host confusion-matrix counters into one set of counters.
+
+  Args:
+    all_auc_states: Tuple `(tp, fp, tn, fn)`; every entry must have shape
+      `(jax.process_count(), num_thresholds)`.
+    num_thresholds: Expected size of the trailing axis of every entry.
+
+  Returns:
+    The tuple `(tp, fp, tn, fn)` with the leading (host) axis summed away.
+
+  Raises:
+    AssertionError: If an entry does not have the expected shape.
+  """
+  # jax can take in trees of arrays, tuple is considered a tree so we can
+  # unpack it here.
+  # each array here has dimensions #host x shape
+
+  all_tp, all_fp, all_tn, all_fn = all_auc_states
+
+  assert all_tp.shape == (jax.process_count(), num_thresholds)
+  assert all_fp.shape == (jax.process_count(), num_thresholds)
+  assert all_tn.shape == (jax.process_count(), num_thresholds)
+  assert all_fn.shape == (jax.process_count(), num_thresholds)
+
+  tp = jnp.sum(all_tp, 0)
+  fp = jnp.sum(all_fp, 0)
+  tn = jnp.sum(all_tn, 0)
+  fn = jnp.sum(all_fn, 0)
+
+  return tp, fp, tn, fn
+
+
+def host_all_gather_metrics(metric):
+  """Returns a copy of `metric` whose weights are summed over all hosts.
+
+  Args:
+    metric: Object providing `get_weights()` and `set_weights()`.
+
+  Returns:
+    A deep copy of `metric` holding the summed weights; `metric` itself is
+    not modified.
+  """
+  states = multihost_utils.process_allgather(metric.get_weights())
+  # Sum every gathered weight over its leading axis (presumably the per-host
+  # axis added by process_allgather -- confirm for the jax version in use).
+  state = jax.tree_util.tree_map(lambda x: np.sum(x, axis=0), states)
+  metric_copy = copy.deepcopy(metric)
+  metric_copy.set_weights(state)
+  return metric_copy
+
+
+class ComputeAUCMetric:
+  """Calculate auc metrics across multiple hosts.
+
+  Args:
+    curve: Forwarded to `tf.keras.metrics.AUC(curve=...)`.
+    num_thresholds: Forwarded to `tf.keras.metrics.AUC(num_thresholds=...)`.
+    from_logits: Forwarded to `tf.keras.metrics.AUC(from_logits=...)`.
+  """
+  def __init__(self, curve, num_thresholds=200, from_logits=False):
+    self.curve = curve
+    self.num_thresholds = num_thresholds
+    self.from_logits = from_logits
+    self.auc = tf.keras.metrics.AUC(curve=self.curve,
+                                    from_logits=self.from_logits,
+                                    num_thresholds=self.num_thresholds)
+
+  def calculate_and_update_scores(self, logits, label, sample_weight):
+    """Updates the local AUC state with `logits` as predictions for `label`."""
+    self.auc.update_state(label, logits, sample_weight=sample_weight)
+
+  def gather_metrics(self):
+    """Combines the counters of all hosts and returns the resulting AUC.
+
+    NOTE: `self.auc` is overwritten with the combined counters, so a second
+    call gathers already-combined state again.
+    """
+    auc_state = keras_auc_to_arrays(self.auc)
+
+    # Gather the data across all hosts.
+    all_auc_states = multihost_utils.process_allgather(auc_state)
+
+    # Sum the per-host counters and write the totals back into self.auc.
+    self.auc = arrays_to_keras_auc(*combine_states(all_auc_states,
+                                                   num_thresholds=self.num_thresholds),
+                                   self.auc)
+
+    return self.auc.result()
+
+
+class ComputeOODAUCMetric:
+  """Calculate auc metrics of an OOD score across multiple hosts.
+
+  Args:
+    curve: 'ROC' or 'PR' for the type of AUC.
+    num_thresholds: Number of thresholds to use for discretizing the roc curve.
+
+  `from_logits` is not configurable here: it is always False, so the score
+  returned by `get_ood_score` reaches the keras AUC without a sigmoid.
+  NOTE(review): `tf.keras.metrics.AUC` spaces its thresholds over [0, 1];
+  confirm that the chosen `get_ood_score` method yields scores in that range.
+
+  """
+  def __init__(self, curve, num_thresholds=200):
+    self.curve = curve
+    self.num_thresholds = num_thresholds
+    self.from_logits = False
+    self.auc = tf.keras.metrics.AUC(curve=self.curve,
+                                    from_logits=self.from_logits,
+                                    num_thresholds=self.num_thresholds)
+
+  def calculate_and_update_scores(self, logits, label, sample_weight, **kwargs):
+    """Scores `logits` with `get_ood_score` and updates the local AUC state.
+
+    Args:
+      logits: Passed to `get_ood_score`.
+      label: OOD labels; the skip test below treats 1 as ood and 0 as not ood.
+      sample_weight: Weights forwarded to the AUC; 0 excludes an entry.
+      **kwargs: Forwarded to `get_ood_score` (e.g. `method_name`).
+    """
+    ood_score = get_ood_score(logits, **kwargs)
+
+    # Skip inputs where every weighted pixel is ood or none is (presumably
+    # because an AUC over a single class is not meaningful). The "all ood"
+    # case must compare against the number of weighted pixels, not against 1.
+    num_ood = jnp.sum(label * sample_weight)
+    all_pixel_ood = num_ood == jnp.sum(sample_weight)
+    no_pixel_ood = num_ood == 0
+
+    if not(all_pixel_ood) and not(no_pixel_ood):
+      self.auc.update_state(label, ood_score, sample_weight=sample_weight)
+
+  def gather_metrics(self):
+    """Combines the AUC state of all hosts and returns the resulting AUC.
+
+    NOTE: `self.auc` is replaced by the combined metric, so a second call
+    gathers already-combined state again.
+    """
+
+    # Gather the metrics:
+    self.auc = host_all_gather_metrics(self.auc)
+
+    return self.auc.result()
+
diff --git a/experimental/robust_segvit/metrics_multihost_test.py b/experimental/robust_segvit/metrics_multihost_test.py
new file mode 100644
index 000000000..fb46d62fd
--- /dev/null
+++ b/experimental/robust_segvit/metrics_multihost_test.py
@@ -0,0 +1,142 @@
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax import jax_utils
+
+from metrics_multihost import ComputeAUCMetric
+from metrics_multihost import ComputeOODAUCMetric
+import sklearn.metrics
+
+from ood_metrics import get_ood_score
+
+class OODMetricsMultiHostTest(parameterized.TestCase):
+
+  def setUp(self):
+    super(OODMetricsMultiHostTest, self).setUp()
+
+  @parameterized.parameters([(0, 0.0), (1, 0.01), (2, 0.5), (3, 0.99), (4, 1)])
+  def test_ComputeAUCMetric(self, seed, masked_fraction):
+    """Test computation of AUC metric."""
+    np.random.seed(seed)
+
+    from_logits = False  # when set to True applies sigmoid to logits.
+    num_thresholds = 100
+
+    # Create test data:
+    num_classes = 2
+    input_shape = [8, 1, 224, 224]
+    logits_shape = input_shape + [num_classes]
+    logits_np = np.random.rand(*logits_shape)
+
+    # Note: randint only draws labels in [0, num_classes); no -1 (excluded) label occurs here.
+    label = np.random.randint(0, num_classes, size=input_shape)
+    label[:4] = np.argmax(logits_np[:4], axis=-1)  # Set half to correct.
+
+    batch_np = {
+        'logits': logits_np,
+        'label':
+            label,
+        'batch_mask':
+            (np.random.rand(*input_shape) > masked_fraction) & (label != -1),
+    }
+    batch = {
+        'logits': jnp.array(logits_np),
+        'label': jnp.array(batch_np['label']),
+        'batch_mask': jnp.array(batch_np['batch_mask']),
+    }
+
+    fake_batches_replicated = jax_utils.replicate([batch])
+
+    auc_roc = ComputeAUCMetric(curve='ROC', num_thresholds=num_thresholds, from_logits=from_logits)
+
+    for fake_batch in fake_batches_replicated:
+      if from_logits:
+        pred = jnp.max(fake_batch['logits'], axis=-1)
+      else:
+        pred = jnp.argmax(fake_batch['logits'], axis=-1)
+      auc_roc.calculate_and_update_scores(logits=pred,
+                                          label=fake_batch['label'],
+                                          sample_weight=fake_batch['batch_mask'],
+                                          )
+
+    auc_result = auc_roc.gather_metrics().numpy()
+
+    # Reference result computed with sklearn (0 when every entry is masked out):
+    if np.all(batch_np['batch_mask'] == 0):
+      auc_numpy = 0
+    else:
+      labels_negative_ignored = np.maximum(batch_np['label'], 0)
+      y_pred = np.argmax(logits_np, axis=-1)
+      auc_numpy = sklearn.metrics.roc_auc_score(labels_negative_ignored.ravel(),
+                                                y_pred.ravel(),
+                                                sample_weight=batch_np['batch_mask'].ravel())
+
+    self.assertAlmostEqual(auc_result, auc_numpy, places=2)
+
+
+  @parameterized.parameters([(0, 0.0), (1, 0.01), (2, 0.5), (3, 0.99), (4, 1)])
+  def test_ComputeOODAUCMetric(self, seed, masked_fraction):
+    """Test computation of OOD scored AUC metric."""
+    np.random.seed(seed)
+    num_thresholds = 1000
+
+    ood_kwargs = {
+        'method_name': 'mlogit',
+    }
+    # Create test data:
+    num_classes = 2
+    input_shape = [8, 1, 224, 224]
+    logits_shape = input_shape + [num_classes]
+    logits_np = np.random.rand(*logits_shape)
+
+    # Note: randint only draws labels in [0, num_classes); no -1 (excluded) label occurs here.
+    label = np.random.randint(0, num_classes, size=input_shape)
+    label[:4] = np.argmax(logits_np[:4], axis=-1)  # Set half to correct.
+
+    batch_np = {
+        'logits': logits_np,
+        'label':
+            label,
+        'batch_mask':
+            (np.random.rand(*input_shape) > masked_fraction) & (label != -1),
+    }
+    batch = {
+        'logits': jnp.array(logits_np),
+        'label': jnp.array(batch_np['label']),
+        'batch_mask': jnp.array(batch_np['batch_mask']),
+    }
+
+    fake_batches_replicated = jax_utils.replicate([batch])
+
+    auc_roc = ComputeOODAUCMetric(curve='ROC', num_thresholds=num_thresholds)
+
+    for fake_batch in fake_batches_replicated:
+      pred = fake_batch['logits']
+      ood_label = 1 - fake_batch['label']
+
+      auc_roc.calculate_and_update_scores(logits=pred,
+                                          label=ood_label,
+                                          sample_weight=fake_batch['batch_mask'],
+                                          **ood_kwargs,
+                                          )
+    auc_result = auc_roc.gather_metrics().numpy()
+
+    # Reference result computed with sklearn (0 when every entry is masked out):
+    if np.all(batch_np['batch_mask'] == 0):
+      auc_numpy = 0
+    else:
+      labels_negative_ignored = np.maximum(batch_np['label'], 0)
+      ood_label_np = 1 - labels_negative_ignored
+      ood_score = get_ood_score(logits_np, **ood_kwargs)
+      auc_numpy = sklearn.metrics.roc_auc_score(ood_label_np.ravel(),
+                                                ood_score.ravel(),
+                                                sample_weight=batch_np['batch_mask'].ravel())
+
+    self.assertAlmostEqual(auc_result, auc_numpy, places=1)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/experimental/robust_segvit/ood_metrics.py b/experimental/robust_segvit/ood_metrics.py
index 36e3b2ce5..2f273b02c 100644
--- a/experimental/robust_segvit/ood_metrics.py
+++ b/experimental/robust_segvit/ood_metrics.py
@@ -81,7 +81,7 @@ def preprocess_outlier(outlier):
 
 def get_ood_score(
     logits: jnp.ndarray,
-    method_name: str = 'msp',
+    method_name: str = 'nmlogit',
     num_top_k: int = 5,
     ) -> Dict[str, Any]:
   """Get OOD score."""
@@ -97,6 +97,9 @@ def get_ood_score(
   elif method_name == 'mlogit':
     max_logits = jnp.max(logits, -1)
     ood_score = 1
- max_logits
+  elif method_name == 'nmlogit':
+    max_logits = jnp.max(logits, -1)
+    ood_score = - 1 * max_logits
   elif method_name == 'sum_topklogit':
     ood_score = jax.lax.top_k(logits, num_top_k)[0].sum(-1)
   elif method_name == '1-sum_topklogit':
@@ -108,6 +111,47 @@ def get_ood_score(
   return ood_score
 
 
+def get_score(
+    logits: jnp.ndarray,
+    method_name: str = 'mlogit',
+    num_top_k: int = 5,
+    ) -> jnp.ndarray:
+  """Computes a score from `logits`, reducing over their last axis.
+
+  Args:
+    logits: Array of logits; every method reduces over its last axis.
+    method_name: 'msp' (max softmax probability), 'entropy' (entropy of the
+      softmax), 'mlogit' (max logit) or 'sum_topklogit' (sum of the
+      `num_top_k` largest logits).
+    num_top_k: Number of logits summed by 'sum_topklogit'.
+
+  Returns:
+    The score array, i.e. `logits` with its last axis reduced away.
+
+  Raises:
+    NotImplementedError: If `method_name` is not one of the names above.
+  """
+
+  if method_name == 'msp':
+    probs = jax.nn.softmax(logits, -1)
+    ood_score = jnp.max(probs, -1)
+  elif method_name == 'entropy':
+    # log_softmax keeps the log finite when a probability underflows to 0;
+    # log(softmax(...)) would turn that term into 0 * -inf = NaN.
+    log_probs = jax.nn.log_softmax(logits, -1)
+    entropy = -jnp.sum(jnp.exp(log_probs) * log_probs, axis=-1)
+    ood_score = entropy
+  elif method_name == 'mlogit':
+    ood_score = jnp.max(logits, -1)
+  elif method_name == 'sum_topklogit':
+    ood_score = jax.lax.top_k(logits, num_top_k)[0].sum(-1)
+  else:
+    raise NotImplementedError(
+        f'Missing method {method_name} to calculate OOD score.')
+  return ood_score
+
+
+
 def get_ood_metrics(
     logits: jnp.ndarray,
     ood_mask: jnp.ndarray,
@@ -148,11 +192,9 @@ def get_ood_metrics(
 
   # the weights per entry are 1 if it should be included during computation
   # and 0 otherwise.
-  # the masked array makes any entry with value 1 as invalid.
-  y_true_masked = np.ma.masked_array(y_true, mask=1-weights)
-  ood_score_masked = np.ma.masked_array(ood_score, mask=1-weights)
+  y_true_masked = y_true[weights == 1]
+  ood_score_masked = ood_score[weights == 1]
 
-  metrics = compute_ood_metrics(y_true_masked.flatten(),
-                                ood_score_masked.flatten())
+  metrics = compute_ood_metrics(y_true_masked.flatten(), ood_score_masked.flatten())
 
   return metrics
diff --git a/experimental/robust_segvit/pretrainer_utils.py b/experimental/robust_segvit/pretrainer_utils.py
index a5c4ffbcc..a3b1d4d76 100644
--- a/experimental/robust_segvit/pretrainer_utils.py
+++ b/experimental/robust_segvit/pretrainer_utils.py
@@ -22,6 +22,7 @@
 import numpy as np
 from scenic.train_lib_deprecated import train_utils
 from tensorflow.io import gfile
+from flax.training import checkpoints
 
 
 def load_bb_config(
@@ -197,6 +198,66 @@ def convert_torch_to_jax_checkpoint(
       optimizer={"target": restored_params},)
   # pytype: enable=wrong-arg-types
 
+  return restored_train_state
+
+
+def convert_vision_transformer_to_scenic(
+    checkpoint_path: str,
+    convert_to_linen: bool = True) -> train_utils.TrainState:
+  """Converts a vision_transformer checkpoint to an scenic train state.
+
+  The model weights come from https://github.com/google-research/vision_transformer.
+
+  Original code: convert_big_vision_to_scenic_checkpoint
+  from https://github.com/google-research/scenic/
+
+  Args:
+    checkpoint_path: Path to the checkpoint; opened with `gfile.GFile` and read with `np.load`.
+    convert_to_linen: Whether to pass the parameters through `checkpoints.convert_pre_linen`.
+
+  Returns:
+    restored_train_state: Scenic train state whose `optimizer['target']` holds
+      the restored parameters; no other field is set here.
+  """
+
+  def unflatten_dict(flattened: Dict[str, Any],
+                     separator: str = '/',
+                     leaf_idx: int = -1) -> Dict[str, Any]:
+    """Rebuilds a nested dict from keys joined by `separator`.
+
+    A non-zero `leaf_idx` first truncates each split key path to `[:leaf_idx]`.
+    """
+    unflattened = {}
+    for k, v in flattened.items():
+      subtree = unflattened
+      if leaf_idx != 0:
+        path = k.split(separator)[:leaf_idx]
+      else:
+        path = k.split(separator)
+      for k2 in path[:-1]:
+        if k2 not in subtree:
+          subtree[k2] = {}
+        subtree = subtree[k2]
+      subtree[path[-1]] = v
+    return unflattened
+
+  logging.info('Loading vision_transformer checkpoint from %s', checkpoint_path)
+  # Read every array while the file is open, so the handle is closed here
+  # instead of staying open behind the loaded archive.
+  with gfile.GFile(checkpoint_path, 'rb') as checkpoint_file:
+    checkpoint_data = dict(np.load(checkpoint_file))
+  restored_params = unflatten_dict(checkpoint_data, separator='/', leaf_idx=0)
+
+  if convert_to_linen:
+    restored_params = checkpoints.convert_pre_linen(restored_params)
+  restored_params = dict(restored_params)
+
+  train_state = train_utils.TrainState()
+  # pytype: disable=wrong-arg-types
+  restored_train_state = train_state.replace(  # pytype: disable=attribute-error
+      optimizer={"target": restored_params},)
+  # pytype: enable=wrong-arg-types
+
   # free memory
   del restored_params
-  return restored_train_state
+  return restored_train_state
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_ade20k_ind_be_eval.yaml b/experimental/robust_segvit/run_ade20k_ind_be_eval.yaml
new file mode 100755
index 000000000..90863139c
--- /dev/null
+++ b/experimental/robust_segvit/run_ade20k_ind_be_eval.yaml
@@ -0,0 +1,39 @@
+name: ade20k_ind_be_eval
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+  name: valid_loss
+  goal: minimize
+parameters:
+  config.use_wandb:
+    value: true
+  config.wandb_project :
+    value: ${{project}}
+  config.wandb_entity :
+    value: ${{entity}}
+  config.batch_size:
+    value: 16
+  config.eval_configs.store_logits:
+    value: false
+  config.eval_covariate_shift:
+    value: false
+  config.eval_robustness_configs.method_name:
+    value: 'msp'
+
+
+command:
+  - ${env}
+  - python
+  - ${program}
+  - "--config"
+  - "configs/ade20k_ind/be_eval.py"
+
- "--output_dir" + - "gs://ub-ekb/segmenter/ade20k_ind/be_eval" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/run_ade20k_ind_deterministic_eval.yaml b/experimental/robust_segvit/run_ade20k_ind_deterministic_eval.yaml new file mode 100755 index 000000000..01347a6a9 --- /dev/null +++ b/experimental/robust_segvit/run_ade20k_ind_deterministic_eval.yaml @@ -0,0 +1,39 @@ +name: ade20k_ind_deterministic_eval +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 16 + config.eval_configs.store_logits: + value: false + config.eval_covariate_shift: + value: true + config.eval_robustness_configs.method_name: + value: 'msp' + + +command: + - ${env} + - python + - ${program} + - "--config" + - "configs/ade20k_ind/deterministic_eval.py" + - "--output_dir" + - "gs://ub-ekb/segmenter/ade20k_ind/deterministic_eval" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/run_ade20k_ind_gp_eval.yaml b/experimental/robust_segvit/run_ade20k_ind_gp_eval.yaml new file mode 100755 index 000000000..0612dcdf1 --- /dev/null +++ b/experimental/robust_segvit/run_ade20k_ind_gp_eval.yaml @@ -0,0 +1,40 @@ +name: ade20k_ind_gp_eval +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 16 + config.eval_configs.store_logits: + value: false + config.eval_covariate_shift: + value: false + config.eval_robustness_configs.method_name: + value: 'msp' + + + +command: + - ${env} 
+ - python + - ${program} + - "--config" + - "configs/ade20k_ind/gp_eval.py" + - "--output_dir" + - "gs://ub-ekb/segmenter/ade20k_ind/gp_eval" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/run_ade20k_ind_het_eval.yaml b/experimental/robust_segvit/run_ade20k_ind_het_eval.yaml new file mode 100755 index 000000000..0ed843f4b --- /dev/null +++ b/experimental/robust_segvit/run_ade20k_ind_het_eval.yaml @@ -0,0 +1,39 @@ +name: ade20k_ind_het_eval +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 16 + config.eval_configs.store_logits: + value: false + config.eval_covariate_shift: + value: false + config.eval_robustness_configs.method_name: + value: 'msp' + + +command: + - ${env} + - python + - ${program} + - "--config" + - "configs/ade20k_ind/het_eval.py" + - "--output_dir" + - "gs://ub-ekb/segmenter/ade20k_ind/het_eval" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/run_cityscapes_be_eval.yaml b/experimental/robust_segvit/run_cityscapes_be_eval.yaml new file mode 100755 index 000000000..b0afa8ebc --- /dev/null +++ b/experimental/robust_segvit/run_cityscapes_be_eval.yaml @@ -0,0 +1,40 @@ +name: cityscapes_be_eval +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 16 + config.eval_configs.store_logits: + value: false + config.eval_covariate_shift: + value: false + config.eval_robustness_configs.method_name: + value: 'msp' + + + 
+command: + - ${env} + - python + - ${program} + - "--config" + - "configs/cityscapes/be_eval.py" + - "--output_dir" + - "gs://ub-ekb/segmenter/cityscapes/be_eval" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/run_cityscapes_deterministic_eval.yaml b/experimental/robust_segvit/run_cityscapes_deterministic_eval.yaml new file mode 100755 index 000000000..b2eced3cc --- /dev/null +++ b/experimental/robust_segvit/run_cityscapes_deterministic_eval.yaml @@ -0,0 +1,39 @@ +name: cityscapes_deterministic_eval +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 16 + config.eval_configs.store_logits: + value: false + config.eval_covariate_shift: + value: false + config.eval_robustness_configs.method_name: + value: 'msp' + + +command: + - ${env} + - python + - ${program} + - "--config" + - "configs/cityscapes/deterministic_eval.py" + - "--output_dir" + - "gs://ub-ekb/segmenter/cityscapes/deterministic_eval" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/run_cityscapes_gp_eval.yaml b/experimental/robust_segvit/run_cityscapes_gp_eval.yaml new file mode 100755 index 000000000..ff44994d1 --- /dev/null +++ b/experimental/robust_segvit/run_cityscapes_gp_eval.yaml @@ -0,0 +1,40 @@ +name: cityscapes_gp_eval +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 16 + config.eval_configs.store_logits: + value: false + config.eval_covariate_shift: + 
value: false + config.eval_robustness_configs.method_name: + value: 'msp' + + + +command: + - ${env} + - python + - ${program} + - "--config" + - "configs/cityscapes/gp_eval.py" + - "--output_dir" + - "gs://ub-ekb/segmenter/cityscapes/gp_eval" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/run_cityscapes_het_eval.yaml b/experimental/robust_segvit/run_cityscapes_het_eval.yaml new file mode 100755 index 000000000..671486f60 --- /dev/null +++ b/experimental/robust_segvit/run_cityscapes_het_eval.yaml @@ -0,0 +1,40 @@ +name: cityscapes_het_eval +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 16 + config.eval_configs.store_logits: + value: false + config.eval_covariate_shift: + value: false + config.eval_robustness_configs.method_name: + value: 'msp' + + + +command: + - ${env} + - python + - ${program} + - "--config" + - "configs/cityscapes/het_eval.py" + - "--output_dir" + - "gs://ub-ekb/segmenter/cityscapes/het_eval" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/run_deterministic_cityscapes.yaml b/experimental/robust_segvit/run_deterministic_cityscapes.yaml new file mode 100755 index 000000000..a9bb2eabd --- /dev/null +++ b/experimental/robust_segvit/run_deterministic_cityscapes.yaml @@ -0,0 +1,34 @@ +name: deterministic_cityscapes +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 16 + + + +command: + - ${env} + - python 
+ - ${program} + - "--config" + - "configs/cityscapes/deterministic.py" + - "--output_dir" + - "gs://ub-ekb/segmenter/cityscapes/deterministic" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/run_deterministic_local.sh b/experimental/robust_segvit/run_deterministic_local.sh new file mode 100755 index 000000000..19c7a2d6a --- /dev/null +++ b/experimental/robust_segvit/run_deterministic_local.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +# train toy model using wandb +#wandb sweep run_toy_mac.yaml +# before make sure we can run code vanilla version: + +DATASET='ade20k_ind' # or cityscapes +DATASET='street_hazards' +DATASET='cityscapes' + +base_output_dir="gs://ub-ekb/segmenter/${DATASET}/deterministic" + +# Debug on Mac OS X platform +use_gpu=False + +if [ "$(uname)" = "Darwin" ] ; then +tpu=False +num_cores=1 +batch_size=5 +elif [ "$(uname)" = "Linux" ]; then +tpu='local' +num_cores=8 +batch_size=8 +fi + +use_wandb=True + +config_file="configs/${DATASET}/deterministic.py:runlocal" +run_name="local" +output_dir="${base_output_dir}/${run_name}" +python deterministic.py \ +--output_dir=${output_dir} \ +--num_cores=$num_cores \ +--use_gpu=$use_gpu \ +--config=${config_file} \ +--config.batch_size=${batch_size} \ +--config.use_wandb=${use_wandb} \ +--tpu=${tpu} \ diff --git a/experimental/robust_segvit/run_deterministic_street_hazards.yaml b/experimental/robust_segvit/run_deterministic_street_hazards.yaml new file mode 100755 index 000000000..182f288e0 --- /dev/null +++ b/experimental/robust_segvit/run_deterministic_street_hazards.yaml @@ -0,0 +1,38 @@ +name: deterministic_street_hazards_hparam +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 16 + 
config.lr_configs.base_learning_rate: + values: [0.0001, 0.00001, 0.0003] + config.num_training_epochs: + values: [50, 100] + + + +command: + - ${env} + - python + - ${program} + - "--config" + - "configs/street_hazards/deterministic.py" + - "--output_dir" + - "gs://ub-ekb/segmenter/street_hazards/deterministic" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/run_eval.sh b/experimental/robust_segvit/run_eval.sh new file mode 100644 index 000000000..f2f120687 --- /dev/null +++ b/experimental/robust_segvit/run_eval.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +# call eval model using wandb + +# Debug on Mac OS X platform +use_gpu=False +if [ "$(uname)" = "Darwin" ] ; then +tpu=False +num_cores=1 +batch_size=5 +elif [ "$(uname)" = "Linux" ]; then +tpu='local' +num_cores=8 +batch_size=8 +fi + +# default config for eval +eval_covariate_shift=False +method_name='msp' +use_wandb=True + +for dataset in "ade20k_ind" "street_hazards" "cityscapes" +do +for model in "gp" "be" "deterministic" "het" +do +base_output_dir="gs://ub-ekb/segmenter/${dataset}" +config_file="configs/${dataset}/${model}_eval.py" +run_name="${model}_eval" +output_dir="${base_output_dir}/${run_name}" +python deterministic.py \ +--output_dir=${output_dir} \ +--num_cores=$num_cores \ +--use_gpu=$use_gpu \ +--config=${config_file} \ +--config.batch_size=${batch_size} \ +--config.eval_robustness_configs.method_name=${method_name} \ +--config.eval_covariate_shift=${eval_covariate_shift} \ +--config.use_wandb=${use_wandb} \ +--tpu=${tpu} \ + +done +done diff --git a/experimental/robust_segvit/run_eval_local.sh b/experimental/robust_segvit/run_eval_local.sh new file mode 100755 index 000000000..16fed5f45 --- /dev/null +++ b/experimental/robust_segvit/run_eval_local.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +# evaluate model using wandb +#wandb sweep run_toy_mac.yaml +# before make sure we can run code vanilla version: + +DATASET='ade20k_ind' # or 
cityscapes +DATASET='street_hazards' + +# Parameters +DATASET='cityscapes' +model='deterministic' + +base_output_dir="gs://ub-ekb/segmenter/${DATASET}/${model}_eval" + +# Debug on Mac OS X platform +use_gpu=False +if [ "$(uname)" = "Darwin" ] ; then +tpu=False +num_cores=1 +batch_size=1 +elif [ "$(uname)" = "Linux" ]; then +tpu='local' +num_cores=8 +batch_size=8 +fi + +config_file="configs/${DATASET}/${model}_eval.py:runlocal" +run_name="${model}_eval" +output_dir="${base_output_dir}/${run_name}" +python deterministic.py \ +--output_dir=${output_dir} \ +--num_cores=$num_cores \ +--use_gpu=$use_gpu \ +--config=${config_file} \ +--config.batch_size=${batch_size} \ +--tpu=${tpu} \ diff --git a/experimental/robust_segvit/run_street_hazards_be.yaml b/experimental/robust_segvit/run_street_hazards_be.yaml new file mode 100755 index 000000000..98742c2c9 --- /dev/null +++ b/experimental/robust_segvit/run_street_hazards_be.yaml @@ -0,0 +1,37 @@ +name: be_street_hazards_hparam +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 24 + config.model.backbone.random_sign_init: + values: [-0.5, -0.25, 0.25, 0.5] + config.fast_weight_lr_multiplier: + values: [0.5, 1.0, 2.0] + + +command: + - ${env} + - python + - ${program} + - "--config" + - "configs/street_hazards/be.py" + - "--output_dir" + - "gs://ub-ekb/segmenter/street_hazards/be" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/run_street_hazards_deterministic_eval.yaml b/experimental/robust_segvit/run_street_hazards_deterministic_eval.yaml new file mode 100755 index 000000000..78e7d7ee0 --- /dev/null +++ b/experimental/robust_segvit/run_street_hazards_deterministic_eval.yaml @@ -0,0 +1,39 @@ +name: 
street_hazards_deterministic_eval +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 8 + config.eval_configs.store_logits: + value: false + config.eval_covariate_shift: + value: false + config.eval_robustness_configs.method_name: + value: 'msp' + + +command: + - ${env} + - python + - ${program} + - "--config" + - "configs/street_hazards/deterministic_eval.py" + - "--output_dir" + - "gs://ub-ekb/segmenter/street_hazards/deterministic_eval" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/run_street_hazards_gp.yaml b/experimental/robust_segvit/run_street_hazards_gp.yaml new file mode 100755 index 000000000..f6bbdf6b9 --- /dev/null +++ b/experimental/robust_segvit/run_street_hazards_gp.yaml @@ -0,0 +1,35 @@ +name: gp_street_hazards_hparam +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 24 + config.model.decoder.mean_field_factor: + values: [1, 2, 5, 6, 10] + + +command: + - ${env} + - python + - ${program} + - "--config" + - "configs/street_hazards/gp.py" + - "--output_dir" + - "gs://ub-ekb/segmenter/street_hazards/gp" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/run_street_hazards_gp_eval.yaml b/experimental/robust_segvit/run_street_hazards_gp_eval.yaml new file mode 100755 index 000000000..988f79c2b --- /dev/null +++ b/experimental/robust_segvit/run_street_hazards_gp_eval.yaml @@ -0,0 +1,39 @@ +name: 
street_hazards_gp_eval +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 8 + config.eval_configs.store_logits: + value: false + config.eval_covariate_shift: + value: false + config.eval_robustness_configs.method_name: + value: 'msp' + + +command: + - ${env} + - python + - ${program} + - "--config" + - "configs/street_hazards/gp_eval.py" + - "--output_dir" + - "gs://ub-ekb/segmenter/street_hazards/gp_eval" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/run_street_hazards_het.yaml b/experimental/robust_segvit/run_street_hazards_het.yaml new file mode 100755 index 000000000..af4fd2799 --- /dev/null +++ b/experimental/robust_segvit/run_street_hazards_het.yaml @@ -0,0 +1,35 @@ +name: het_street_hazards_hparam +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 24 + config.model.decoder.temperature: + values: [0.15, 0.3, 1, 1.5, 2.0] + + +command: + - ${env} + - python + - ${program} + - "--config" + - "configs/street_hazards/het.py" + - "--output_dir" + - "gs://ub-ekb/segmenter/street_hazards/het" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/run_toy_eval.sh b/experimental/robust_segvit/run_toy_eval.sh new file mode 100755 index 000000000..67426937b --- /dev/null +++ b/experimental/robust_segvit/run_toy_eval.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +# train toy model using wandb +#wandb sweep run_toy_mac.yaml +# before make sure we 
can run code vanilla version: + +DATASET='ade20k_ind' # or cityscapes +DATASET='street_hazards' +DATASET='cityscapes' + +base_output_dir="gs://ub-ekb/segmenter/${DATASET}/toy_model" + +# Debug on Mac OS X platform +use_gpu=False +if [ "$(uname)" = "Darwin" ] ; then +tpu=False +num_cores=1 +batch_size=5 +elif [ "$(uname)" = "Linux" ]; then +tpu='local' +num_cores=8 +batch_size=8 +fi + +use_wandb=True +config_file="configs/${DATASET}/toy_model_eval.py:runlocal" +run_name="toy_model_eval" +output_dir="${base_output_dir}/${run_name}" +python deterministic.py \ +--output_dir=${output_dir} \ +--num_cores=$num_cores \ +--use_gpu=$use_gpu \ +--config=${config_file} \ +--config.batch_size=${batch_size} \ +--config.use_wandb=${use_wandb} \ +--tpu=${tpu} \ diff --git a/experimental/robust_segvit/run_toy_mac.sh b/experimental/robust_segvit/run_toy_mac.sh new file mode 100755 index 000000000..adc4c613c --- /dev/null +++ b/experimental/robust_segvit/run_toy_mac.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +# ---------------------------------------------------- +# train toy model on a DATASET: +# ---------------------------------------------------- + +# to train toy model and track performance using wandb: +# wandb sweep run_toy_mac.yaml + +DATASET='ade20k_ind' +DATASET='street_hazards' +DATASET='cityscapes' + +# ---------------------------------------------------- +# Set directory where outputs should be installed: +# ---------------------------------------------------- +# can write results directly to gcp bucket +# base_output_dir="gs://ub-ekb/segmenter/${DATASET}/toy_model" +dt=$(date +"%Y-%m-%d-%H-%M-%S") + +base_output_dir="results/${DATASET}" + +run_name="toy_model" +output_dir="${base_output_dir}/${run_name}/${dt}" +# ---------------------------------------------------- +# Set device configuration for Mac OS X platform +# or TPU v2-8/v3-8 frameworks. 
+# ---------------------------------------------------- +use_gpu=False +if [ "$(uname)" = "Darwin" ] ; then +tpu=False +num_cores=1 +batch_size=5 +elif [ "$(uname)" = "Linux" ]; then +tpu='local' +num_cores=8 +batch_size=8 +fi + +# ---------------------------------------------------- +# Set configuration file +# ---------------------------------------------------- +config_file="configs/${DATASET}/toy_model.py:runlocal" +use_wandb=True +eval_covariate_shift=False + +# ---------------------------------------------------- +# Call model trainer. +# ---------------------------------------------------- +python deterministic.py \ +--output_dir=${output_dir} \ +--num_cores=$num_cores \ +--use_gpu=$use_gpu \ +--config=${config_file} \ +--config.batch_size=${batch_size} \ +--config.eval_covariate_shift=${eval_covariate_shift} \ +--config.use_wandb=${use_wandb} \ +--tpu=${tpu} \ diff --git a/experimental/robust_segvit/run_toy_mac.yaml b/experimental/robust_segvit/run_toy_mac.yaml new file mode 100755 index 000000000..dba4db0bf --- /dev/null +++ b/experimental/robust_segvit/run_toy_mac.yaml @@ -0,0 +1,34 @@ +name: toy_model +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 8 + + + +command: + - ${env} + - python + - ${program} + - "--config" + - "configs/cityscapes/toy_model.py:runlocal" + - "--output_dir" + - "gs://ub-ekb/segmenter/cityscapes/toy_model" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/run_train_seed.sh b/experimental/robust_segvit/run_train_seed.sh new file mode 100644 index 000000000..0625dcc43 --- /dev/null +++ b/experimental/robust_segvit/run_train_seed.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +# call eval model using wandb + +# Debug on Mac 
OS X platform +use_gpu=False +if [ "$(uname)" = "Darwin" ] ; then +tpu=False +num_cores=1 +batch_size=5 +elif [ "$(uname)" = "Linux" ]; then +tpu='local' +num_cores=8 +batch_size=16 +fi + +# default config for eval +use_wandb=True + +for dataset in "cityscapes" #"ade20k_ind" "street_hazards" +do +for model in "deterministic" "gp" "het" "be" +do +for rng_seed in 1 +do +base_output_dir="gs://ub-ekb/segmenter/${dataset}" +config_file="configs/${dataset}/${model}.py" +run_name="${model}_eval" +output_dir="${base_output_dir}/${run_name}" +python deterministic.py \ +--output_dir=${output_dir} \ +--num_cores=$num_cores \ +--use_gpu=$use_gpu \ +--config=${config_file} \ +--config.batch_size=${batch_size} \ +--config.use_wandb=${use_wandb} \ +--config.rng_seed=${rng_seed} \ +--tpu=${tpu} \ + +done +done +done diff --git a/experimental/robust_segvit/store_cityscapes_be_eval.yaml b/experimental/robust_segvit/store_cityscapes_be_eval.yaml new file mode 100755 index 000000000..16d6baa9f --- /dev/null +++ b/experimental/robust_segvit/store_cityscapes_be_eval.yaml @@ -0,0 +1,37 @@ +name: cityscapes_be_eval +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 16 + config.eval_configs.store_logits: + value: true + config.eval_covariate_shift: + value: false + + +command: + - ${env} + - python + - ${program} + - "--config" + - "configs/cityscapes/be_eval.py" + - "--output_dir" + - "cityscapes/be_eval" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/store_cityscapes_deterministic_eval.yaml b/experimental/robust_segvit/store_cityscapes_deterministic_eval.yaml new file mode 100755 index 000000000..f57045b45 --- /dev/null +++ 
b/experimental/robust_segvit/store_cityscapes_deterministic_eval.yaml @@ -0,0 +1,37 @@ +name: cityscapes_deterministic_eval +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 16 + config.eval_configs.store_logits: + value: true + config.eval_covariate_shift: + value: false + + +command: + - ${env} + - python + - ${program} + - "--config" + - "configs/cityscapes/deterministic_eval.py" + - "--output_dir" + - "cityscapes/deterministic_eval" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/store_cityscapes_gp_eval.yaml b/experimental/robust_segvit/store_cityscapes_gp_eval.yaml new file mode 100755 index 000000000..87b82c3cb --- /dev/null +++ b/experimental/robust_segvit/store_cityscapes_gp_eval.yaml @@ -0,0 +1,37 @@ +name: cityscapes_gp_eval +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 16 + config.eval_configs.store_logits: + value: true + config.eval_covariate_shift: + value: false + + +command: + - ${env} + - python + - ${program} + - "--config" + - "configs/cityscapes/gp_eval.py" + - "--output_dir" + - "cityscapes/gp_eval" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/store_cityscapes_het_eval.yaml b/experimental/robust_segvit/store_cityscapes_het_eval.yaml new file mode 100755 index 000000000..ba6eadaff --- /dev/null +++ b/experimental/robust_segvit/store_cityscapes_het_eval.yaml @@ -0,0 +1,37 @@ +name: 
cityscapes_het_eval +program: deterministic.py +method: grid +project: rdl-debug +entity: ekellbuch + +metric: + name: valid_loss + goal: minimize +parameters: + config.use_wandb: + value: true + config.wandb_project : + value: ${{project}} + config.wandb_entity : + value: ${{entity}} + config.batch_size: + value: 16 + config.eval_configs.store_logits: + value: true + config.eval_covariate_shift: + value: false + + +command: + - ${env} + - python + - ${program} + - "--config" + - "configs/cityscapes/het_eval.py" + - "--output_dir" + - "cityscapes/het_eval" + - "--num_cores" + - "8" + - "--tpu" + - "local" + - ${args} \ No newline at end of file diff --git a/experimental/robust_segvit/uncertainty_metrics.py b/experimental/robust_segvit/uncertainty_metrics.py index 031c1fecb..e22934c9f 100644 --- a/experimental/robust_segvit/uncertainty_metrics.py +++ b/experimental/robust_segvit/uncertainty_metrics.py @@ -15,11 +15,10 @@ """Calculate uncertainty metrics for segmentation tasks.""" from typing import Optional, Tuple +import jax from jax import lax import jax.numpy as jnp -from scenic.model_lib.base_models.model_utils import apply_weights - -# TODO(kellybuchanan): reconcile cases where mask is 0. 
+from scenic.model_lib.layers import nn_ops def calculate_num_patches_binary_maps( @@ -29,7 +28,7 @@ def calculate_num_patches_binary_maps( Args: binary_acc_map : binary accuracy map - binary_unc_map : binary uncertainty map + binary_unc_map : binary uncertainty map (1=certain, 0=uncertain) Returns: metrics to calculate uncertainty scores @@ -37,24 +36,24 @@ def calculate_num_patches_binary_maps( # number of patches that are accurate and certain n_ac = jnp.sum( jnp.logical_and( - jnp.equal(binary_acc_map, 1), jnp.equal(binary_unc_map, 0)), + jnp.equal(binary_acc_map, 1), jnp.equal(binary_unc_map, 1)), axis=(-1, -2)) # number of patches that are inaccurate and certain n_ic = jnp.sum( jnp.logical_and( - jnp.equal(binary_acc_map, 0), jnp.equal(binary_unc_map, 0)), + jnp.equal(binary_acc_map, 0), jnp.equal(binary_unc_map, 1)), axis=(-1, -2)) # number of patches that are inaccurate and uncertain n_iu = jnp.sum( jnp.logical_and( - jnp.equal(binary_acc_map, 0), jnp.equal(binary_unc_map, 1)), + jnp.equal(binary_acc_map, 0), jnp.equal(binary_unc_map, 0)), axis=(-1, -2)) # number of patches that are accurate and uncertain n_au = jnp.sum( jnp.logical_and( - jnp.equal(binary_acc_map, 1), jnp.equal(binary_unc_map, 1)), + jnp.equal(binary_acc_map, 1), jnp.equal(binary_unc_map, 0)), axis=(-1, -2)) unc_confusion_matrix = jnp.stack((n_ac, n_ic, n_iu, n_au), axis=-1) @@ -86,17 +85,18 @@ def get_pavpu(unc_confusion_matrix): def get_uncertainty_confusion_matrix( + *, logits: jnp.ndarray, labels: jnp.ndarray, + uncertainty_measure: str = 'softmax', + accuracy_measure : str = 'predictive_accuracy', weights: Optional[jnp.ndarray] = None, accuracy_th: Optional[float] = 0.5, uncertainty_th: Optional[float] = 0.5, - window_size: Optional[int] = 2 + window_size: Optional[int] = 2, ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Calculate counts of patches accurate/inacurate and certain/uncertain. - TODO(kellybuchanan): include weights for entropy calculation. 
- Args: logits: predicted logits labels: true labels @@ -118,29 +118,43 @@ def get_uncertainty_confusion_matrix( preds = jnp.argmax(logits, axis=-1) # calculate binary accuracy map - correct = jnp.equal(preds, targets) + correct = jnp.equal(preds, targets).astype(jnp.float32) - # batch masking - if weights is not None: - correct = apply_weights(correct, weights) + if weights is None: + weights = jnp.ones(correct.shape) - correct = correct.astype(jnp.float32) + weights = weights.astype(jnp.float32) - # A given patch is accurate if its acc > accuracy_threshold - binary_acc_map = reduce_2dmap(correct, window_size, - accuracy_th).astype(jnp.float32) + if accuracy_measure == 'predictive_accuracy': + accuracy_map = correct + else: + raise NotImplementedError('Accuracy measure not implemented.') - # Calculate uncertainty map - entropy = get_entropy_from_logits(logits) + # A given patch is accurate if its acc > accuracy_threshold + binary_acc_map = reduce_2dmap_weighted(accuracy_map, + weights, + window_size=window_size, + threshold=accuracy_th).astype(jnp.float32) + + # Calculate uncertainty map: + if uncertainty_measure == 'softmax': + uncertainty_map = jnp.max(jax.nn.softmax(logits, -1), -1) + elif uncertainty_measure == 'entropy': + uncertainty_map = get_entropy_from_logits(logits) + else: + raise NotImplementedError(f'Uncertainty measure {uncertainty_measure} not implemented.') - # A given patch is uncertain if its uncertainty > uncertainty_th - binary_unc_map = reduce_2dmap(entropy, window_size, - uncertainty_th).astype(jnp.float32) + # A patch is marked certain (1) when its pooled uncertainty_map value is >= uncertainty_th. NOTE(review): fits 'softmax' (max probability); looks inverted for 'entropy' - confirm. + binary_unc_map = reduce_2dmap_weighted(uncertainty_map, + weights, + window_size=window_size, + threshold=uncertainty_th).astype(jnp.float32) # number of patches that are accurate and certain unc_confusion_matrix = calculate_num_patches_binary_maps( binary_acc_map, binary_unc_map) + unc_confusion_matrix = unc_confusion_matrix[jnp.newaxis, ...] # Dummy batch dim.
return unc_confusion_matrix @@ -159,14 +173,14 @@ def reduce_2dmap( """Given a map, apply a 2d spatial strided convolution to avg adjacent values. Args: - array_map: array to be split. + array_map: array to be split. 3-D Tensor; With shape `[batch, in_rows, in_cols]`. window_size: size of window. threshold: threshold for binarization. Returns: binary_map: binary map. """ - # Expand dimension to match filter C dimension. + # Expand dimension for dummy depth dimension array_map = jnp.expand_dims(array_map, -1) # Create a kernel @@ -193,6 +207,44 @@ def reduce_2dmap( return binary_map.astype(jnp.int32) +def reduce_2dmap_weighted( + array_map: jnp.ndarray, + weights: jnp.ndarray, + window_size: int = 4, + threshold: float = 0.5, +) -> jnp.ndarray: + """Given a map, apply a pooling operation to avg adjacent values. + + Args: + array_map: array to be split. 3-D Tensor; With shape `[batch, in_rows, in_cols]`. + weights: array of weights. 3-D Tensor; With shape `[batch, in_rows, in_cols]`. + window_size: size of window. + threshold: threshold for binarization. + + Returns: + binary_map: binary map.
+ """ + # Expand dimension for dummy feature dimension + array_map = jnp.expand_dims(array_map, -1) + + window_shape = (window_size, window_size) + + outputs = nn_ops.weighted_avg_pool( + array_map, + weights, + window_shape=window_shape, + strides=window_shape, + padding='VALID') + + # Binarize_map according to threshold + binary_map = jnp.greater_equal(outputs, threshold) + + # Squeeze dummy feature dimension + binary_map = jnp.squeeze(binary_map, -1) + + return binary_map.astype(jnp.int32) + + class SegmentationUncertaintyMetrics(object): """Calculate uncertainty scores for image segmentation task.""" diff --git a/experimental/robust_segvit/uncertainty_metrics_test.py b/experimental/robust_segvit/uncertainty_metrics_test.py deleted file mode 100644 index c3a4c2a78..000000000 --- a/experimental/robust_segvit/uncertainty_metrics_test.py +++ /dev/null @@ -1,92 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Uncertainty Baselines Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for uncertainty_metrics.""" - -from absl.testing import absltest -from absl.testing import parameterized -import jax.numpy as jnp -from uncertainty_metrics import reduce_2dmap # local file import from experimental.robust_segvit -from uncertainty_metrics import SegmentationUncertaintyMetrics # local file import from experimental.robust_segvit - - -class UncertaintyMetricsTest(parameterized.TestCase): - - def setUp(self): - super(UncertaintyMetricsTest, self).setUp() - self.targets = jnp.asarray([[[1, 2, 5, 7], [6, 4, 3, 3], [10, 9, 5, 0], - [8, 6, 4, 4]]]) - - self.preds = jnp.asarray([[[1, 2, 4, 7], [5, 6, 3, 3], [10, 9, 4, 0], - [8, 7, 3, 4]]]) - - self.unc_map = jnp.asarray([[[0.1, 0.3, 0.6, 0.3], [0.7, 0.6, 0.2, 0.1], - [0.2, 0.4, 0.5, 0.3], [0.1, 0.7, 0.6, 0.2]]]) - - # create logit map from unc_map by mapping entropy vals (0.1,0.7) - # to a feasible range of logit vals:(4.1, 6.2) - self.logit_map = 4.1 + (0.7 - self.unc_map) * (6.2 - 4.1) / (0.7 - 0.1) - - self.window_size = 2 - self.accuracy_th = 0.5 - self.uncertainty_th = 0.4 - - # true values - self.true_binary_acc_map = jnp.asarray([[[0, 1], [1, 0]]]) - self.true_binary_unc_map = jnp.asarray([[[1, 0], [0, 0]]]) - self.true_p_accurate_certain = jnp.asarray([0.67]) - self.true_p_uncertain_innacurate = jnp.asarray([0.5]) - self.true_pavpu = jnp.asarray([0.75]) - - # construct logits passed as input from unc_map - self.num_classes = 11 - self.img_size = 4 - - true_mask = jnp.arange(self.img_size * self.img_size - ) * self.num_classes + self.preds.flatten() - logits = jnp.zeros((self.img_size * self.img_size * self.num_classes)) - logits = logits.at[true_mask].set(self.logit_map.flatten()) - self.logits = jnp.expand_dims( - logits.reshape((self.img_size, self.img_size, self.num_classes)), 0) - - def test_setup(self): - preds_logits = jnp.argmax(self.logits, -1) - self.assertTrue(jnp.array_equal(self.preds, preds_logits)) - - def test_calculate_pacc_cert(self): - segment_unc = 
SegmentationUncertaintyMetrics( - logits=self.logits, - labels=self.targets, - window_size=self.window_size, - accuracy_th=self.accuracy_th, - uncertainty_th=self.uncertainty_th) - - self.assertEqual(self.true_pavpu, segment_unc.pavpu) - self.assertAlmostEqual(self.true_p_accurate_certain, segment_unc.pacc_cert, - 2) - self.assertAlmostEqual(self.true_p_uncertain_innacurate, - segment_unc.puncert_inacc, 2) - - @parameterized.parameters((1), (2), (3)) - def test_reduce_2dmap(self, batch_size): - array_map = jnp.repeat(jnp.ones((1, 4, 4)), batch_size, axis=0) - true_binary_map = jnp.repeat(jnp.ones((1, 2, 2)), batch_size, axis=0) - binary_map = reduce_2dmap(array_map, self.window_size, self.accuracy_th) - - self.assertTrue(jnp.array_equal(true_binary_map, binary_map)) - - -if __name__ == '__main__': - absltest.main()