diff --git a/experimental/robust_segvit/README.md b/experimental/robust_segvit/README.md
index 315313b7b..9e0218e5d 100644
--- a/experimental/robust_segvit/README.md
+++ b/experimental/robust_segvit/README.md
@@ -1,16 +1,25 @@
# Robust segvit
-*Robust_segvit* is a codebase to evaluate the robustness of semantic segmentation models.
+**Robust_segvit** is a codebase to evaluate the robustness of semantic segmentation models. The code is built on top of [uncertainty_baselines](https://github.com/google/uncertainty-baselines) and [Scenic](https://github.com/google-research/scenic).
-Robust_segvit is developed in [JAX](https://github.com/google/jax) and uses [Flax](https://github.com/google/flax), [uncertainty_baselines](https://github.com/google/uncertainty-baselines) and [Scenic](https://github.com/google-research/scenic).
+## Installation
+Robust_segvit is developed in [JAX](https://github.com/google/jax)/[Flax](https://github.com/google/flax).
-## Code structure
-See uncertainty_baselines/google/experimental/cityscapes.
+To run the code:
+1. Install [uncertainty_baselines](https://github.com/google/uncertainty-baselines).
+2. Install [Scenic](https://github.com/google-research/scenic).
+3. Follow the instructions for a toy run in [./run_deterministic_mac.sh](./run_deterministic_mac.sh).
+## Datasets
+The experiment configurations for the different datasets are in:
-## Cityscapes
+- configs/cityscapes: Cityscapes dataset.
+- configs/ade20k_ind: ADE20k_ind dataset.
+- configs/street_hazards: Street Hazards dataset.
-We investigate the performance of different reliability methods on image segmentation tasks.
+## Comments:
+- The checkpoint used for finetuning is the same as the one used by the original Segmenter model: [vit_large_patch16_384](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py)
-[x] configs/cityscapes: contains experiment configurations for the cityscapes dataset.
+## Citing work:
+If you reference this code, please cite [our paper](https://github.com/google/uncertainty-baselines).
\ No newline at end of file
diff --git a/experimental/robust_segvit/___init__.py b/experimental/robust_segvit/___init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/experimental/robust_segvit/checkpoint_utils.py b/experimental/robust_segvit/checkpoint_utils.py
new file mode 100644
index 000000000..eeb59d054
--- /dev/null
+++ b/experimental/robust_segvit/checkpoint_utils.py
@@ -0,0 +1,71 @@
+# load checkpoints
+from scenic.train_lib_deprecated import train_utils
+from pretrainer_utils import convert_torch_to_jax_checkpoint # local file import from experimental.robust_segvit
+from scenic.train_lib_deprecated import pretrain_utils
+from pretrainer_utils import convert_vision_transformer_to_scenic # local file import from experimental.robust_segvit
+
+
+def load_checkpoints_eval(config, model, train_state, workdir):
+ checkpoint_configs = config.get('checkpoint_configs', False)
+ if checkpoint_configs:
+ # Load torch weights
+ if 'torch' in checkpoint_configs.checkpoint_format:
+
+ bb_train_state = convert_torch_to_jax_checkpoint(
+ checkpoint_path=checkpoint_configs.checkpoint_path,
+ config=checkpoint_configs)
+
+ train_state = model.init_backbone_from_train_state(
+ train_state,
+ bb_train_state,
+ config,
+ checkpoint_configs
+ )
+ del bb_train_state
+
+ # Load weights in checkpoint_path or workdir
+ else:
+ checkpoint_path = checkpoint_configs.get('checkpoint_path', workdir)
+ train_state, _ = train_utils.restore_checkpoint(
+ checkpoint_path, train_state)
+ return train_state
+
+
+def load_checkpoints_backbone(config, model, train_state, workdir):
+ del workdir
+ # TODO(kellybuchanan): check out partial loader in
+ # https://github.com/google/uncertainty-baselines/commit/083b1dcc52bb1964f8917d15552ece8848d582ae#
+ restored_model_cfg = config.get('pretrained_backbone_configs')
+
+ # Load pretrained backbone
+ if restored_model_cfg.checkpoint_format in ('ub', 'big_vision', 'scenic'):
+ # load params from checkpoint
+ bb_train_state = pretrain_utils.convert_big_vision_to_scenic_checkpoint(
+ checkpoint_path=restored_model_cfg.checkpoint_path,
+ convert_to_linen=False)
+
+ train_state = model.init_backbone_from_train_state(
+ train_state,
+ bb_train_state,
+ config,
+ restored_model_cfg,
+ model_prefix_path=['backbone'])
+ # Free unnecessary memory.
+ del bb_train_state
+ # Loader from scenic
+ elif restored_model_cfg.checkpoint_format in ('vision_transformer'):
+ # load params from checkpoint
+ bb_train_state = convert_vision_transformer_to_scenic(checkpoint_path=restored_model_cfg.checkpoint_path, convert_to_linen=False)
+
+ train_state = model.init_backbone_from_train_state(
+ train_state,
+ bb_train_state,
+ config,
+ restored_model_cfg,
+ model_prefix_path=['backbone'])
+
+ # Free unnecessary memory.
+ del bb_train_state
+ else:
+ raise NotImplementedError('')
+ return train_state
diff --git a/experimental/robust_segvit/configs/ade20k_ind/be.py b/experimental/robust_segvit/configs/ade20k_ind/be.py
index 9e149362a..e0410ec30 100644
--- a/experimental/robust_segvit/configs/ade20k_ind/be.py
+++ b/experimental/robust_segvit/configs/ade20k_ind/be.py
@@ -22,6 +22,8 @@
# pylint: enable=line-too-long
import ml_collections
+import os
+import datetime
_CITYSCAPES_FINE_TRAIN_SIZE = 2975
_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
@@ -40,21 +42,24 @@
# Model specs.
LOAD_PRETRAINED_BACKBONE = True
-BACKBONE_ORIGIN = 'big_vision'
+BACKBONE_ORIGIN = 'vision_transformer'
VIT_SIZE = 'L'
STRIDE = 16
RESNET_SIZE = None
CLASSIFIER = 'token'
target_size = (640, 640)
-UPSTREAM_TASK = 'i21k+imagenet2012'
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
# Upstream
MODEL_PATHS = {
# Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384
- ('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'):
- 'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
+ 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
+ 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', # pylint: disable=g-long-lambda
+
}
@@ -180,9 +185,25 @@ def get_config(runlocal=''):
config.eval_label_shift = False
config.model.input_shape = target_size
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'msp'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
if runlocal:
config.count_flops = False
config.dataset_configs.train_target_size = (128, 128)
+ config.model.input_shape = config.dataset_configs.train_target_size
config.batch_size = 8
config.num_training_epochs = 5
config.warmup_steps = 0
diff --git a/experimental/robust_segvit/configs/ade20k_ind/be_eval.py b/experimental/robust_segvit/configs/ade20k_ind/be_eval.py
index f70890536..df0aa2d3a 100644
--- a/experimental/robust_segvit/configs/ade20k_ind/be_eval.py
+++ b/experimental/robust_segvit/configs/ade20k_ind/be_eval.py
@@ -20,6 +20,8 @@
# pylint: enable=line-too-long
import ml_collections
+import datetime
+import os
_CITYSCAPES_FINE_TRAIN_SIZE = 2975
_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
@@ -42,14 +44,14 @@
STRIDE = 16
RESNET_SIZE = None
CLASSIFIER = 'token'
-EXPERIMENTID = '43838358-2'
+EXPERIMENTID = '45349725-1'
target_size = (640, 640)
# Upstream
CHECKPOINT_PATHS = {
- ('ub', 'L', 16, None, 'token', '43838358-2'):
- 'gs://ub-ekb/checkpoints_to_upload/ade20k/43838358-2',
+ ('ub', 'L', 16, None, 'token', '45349725-1'):
+ 'gs://ub-checkpoints/45349725-ade20k_ind_segmenter_be/1',
}
@@ -162,17 +164,33 @@ def get_config(runlocal=''):
config.eval_mode = True
config.eval_configs = ml_collections.ConfigDict()
config.eval_configs.mode = 'standard'
- config.eval_covariate_shift = True
- config.eval_label_shift = True
config.model.input_shape = target_size
+ config.eval_configs.store_logits = False
# Eval parameters for robustness
config.eval_label_shift = True
config.eval_covariate_shift = True
config.eval_robustness_configs = ml_collections.ConfigDict()
config.eval_robustness_configs.auc_online = True
- config.eval_robustness_configs.method_name = 'msp'
- config.eval_robustness_configs.num_top_k = 5
+ config.eval_robustness_configs.method_name = 'nmlogit'
+ config.eval_robustness_configs.num_top_k = 1
+
+ # Load checkpoint
+ config.checkpoint_configs = ml_collections.ConfigDict()
+ config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN
+ config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH
+ config.checkpoint_configs.classifier = 'token'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
if runlocal:
config.count_flops = False
@@ -183,12 +201,6 @@ def get_config(runlocal=''):
config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]'
config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]'
config.num_train_examples = TRAIN_SAMPLES
- else:
- # Load checkpoint
- config.checkpoint_configs = ml_collections.ConfigDict()
- config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN
- config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH
- config.checkpoint_configs.classifier = 'token'
return config
diff --git a/experimental/robust_segvit/configs/ade20k_ind/deterministic.py b/experimental/robust_segvit/configs/ade20k_ind/deterministic.py
index 9472fd60b..33883665a 100644
--- a/experimental/robust_segvit/configs/ade20k_ind/deterministic.py
+++ b/experimental/robust_segvit/configs/ade20k_ind/deterministic.py
@@ -22,6 +22,8 @@
# pylint: enable=line-too-long
import ml_collections
+import os
+import datetime
_CITYSCAPES_FINE_TRAIN_SIZE = 2975
_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
@@ -40,21 +42,23 @@
# Model specs.
LOAD_PRETRAINED_BACKBONE = True
-BACKBONE_ORIGIN = 'big_vision'
+BACKBONE_ORIGIN = 'vision_transformer'
VIT_SIZE = 'L'
STRIDE = 16
RESNET_SIZE = None
CLASSIFIER = 'token'
target_size = (640, 640)
-UPSTREAM_TASK = 'i21k+imagenet2012'
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
# Upstream
MODEL_PATHS = {
# Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384
- ('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'):
- 'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
+ 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
+ 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
}
@@ -174,9 +178,25 @@ def get_config(runlocal=''):
config.eval_label_shift = False
config.model.input_shape = target_size
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'msp'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
if runlocal:
config.count_flops = False
config.dataset_configs.train_target_size = (128, 128)
+ config.model.input_shape = config.dataset_configs.train_target_size
config.batch_size = 8
config.num_training_epochs = 5
config.warmup_steps = 0
diff --git a/experimental/robust_segvit/configs/ade20k_ind/deterministic_eval.py b/experimental/robust_segvit/configs/ade20k_ind/deterministic_eval.py
new file mode 100644
index 000000000..e5f8ff97e
--- /dev/null
+++ b/experimental/robust_segvit/configs/ade20k_ind/deterministic_eval.py
@@ -0,0 +1,243 @@
+# coding=utf-8
+# Copyright 2022 The Uncertainty Baselines Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Eval segmenter model on ade20k_ind.
+
+
+"""
+# pylint: enable=line-too-long
+
+import ml_collections
+import os
+import datetime
+
+_CITYSCAPES_FINE_TRAIN_SIZE = 2975
+_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
+
+_ADE20K_TRAIN_SIZE = 20210
+_PASCAL_VOC_TRAIN_SIZE = 10582
+_PASCAL_CONTEXT_TRAIN_SIZE = 4998
+
+TRAIN_SIZES = {
+ 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE,
+ 'ade20k': _ADE20K_TRAIN_SIZE,
+ 'ade20k_ind': _ADE20K_TRAIN_SIZE,
+ 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE,
+ 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE
+}
+
+# Model specs.
+VIT_SIZE = 'L'
+STRIDE = 16
+RESNET_SIZE = None
+CLASSIFIER = 'token'
+target_size = (640, 640)
+
+CHECKPOINT_ORIGIN = 'ub'
+EXPERIMENTID = '45373386-1'
+
+# Upstream
+CHECKPOINT_PATHS = {
+ ('ub', 'L', 16, None, 'token', '45373386-1'):
+ 'gs://ub-checkpoints/45373386-ade20k_ind_deterministic/1',
+}
+
+
+CHECKPOINT_PATH = CHECKPOINT_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE,
+ RESNET_SIZE, CLASSIFIER, EXPERIMENTID)]
+
+if VIT_SIZE == 'B':
+ mlp_dim = 3072
+ num_heads = 12
+ num_layers = 12
+ hidden_size = 768
+elif VIT_SIZE == 'L':
+ mlp_dim = 4096
+ num_heads = 16
+ num_layers = 24
+ hidden_size = 1024
+
+TRAIN_SAMPLES = 32
+
+
+def get_config(runlocal=''):
+ """Returns the configuration for ADE20k_ind segmentation."""
+
+ runlocal = bool(runlocal)
+
+ config = ml_collections.ConfigDict()
+ config.experiment_name = 'ade20k_ind_deterministic_eval'
+
+ # Dataset.
+ config.dataset_name = 'robust_segvit_segmentation'
+ config.dataset_configs = ml_collections.ConfigDict()
+ config.dataset_configs.target_size = target_size
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
+
+ config.dataset_configs.train_split = 'train'
+ config.dataset_configs.name = 'ade20k_ind'
+ config.dataset_configs.dataset_name = '' # ood name flag to write in eval.
+
+ # Model.
+ config.model_name = 'segvit'
+ config.model = ml_collections.ConfigDict()
+
+ config.model.patches = ml_collections.ConfigDict()
+ config.model.patches.size = (STRIDE, STRIDE)
+
+ config.model.backbone = ml_collections.ConfigDict()
+ config.model.backbone.type = 'vit'
+ config.model.backbone.mlp_dim = mlp_dim
+ config.model.backbone.num_heads = num_heads
+ config.model.backbone.num_layers = num_layers
+ config.model.backbone.hidden_size = hidden_size
+ config.model.backbone.dropout_rate = 0.0
+ config.model.backbone.attention_dropout_rate = 0.0
+ config.model.backbone.classifier = CLASSIFIER
+
+ # Decoder
+ config.model.decoder = ml_collections.ConfigDict()
+ config.model.decoder.type = 'linear'
+
+ # Training.
+ config.trainer_name = 'segvit_trainer'
+ config.optimizer = 'adam'
+ config.optimizer_configs = ml_collections.ConfigDict()
+ config.l2_decay_factor = 0.0
+ config.max_grad_norm = 1.0
+ config.label_smoothing = None
+ config.num_training_epochs = ml_collections.FieldReference(100)
+ config.batch_size = 32
+ config.rng_seed = 0
+ config.focal_loss_gamma = 0.0
+
+ # Learning rate.
+ config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name)
+ config.steps_per_epoch = config.get_ref(
+ 'num_train_examples') // config.get_ref('batch_size')
+ # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine.
+ config.lr_configs = ml_collections.ConfigDict()
+ config.lr_configs.learning_rate_schedule = 'compound'
+ config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+ config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch')
+ config.lr_configs.steps_per_cycle = config.get_ref(
+ 'num_training_epochs') * config.get_ref('steps_per_epoch')
+ config.lr_configs.base_learning_rate = 1e-4
+
+ # model and data dtype
+ config.model_dtype_str = 'float32'
+ config.data_dtype_str = 'float32'
+
+ # Logging.
+ config.write_summary = True
+ config.write_xm_measurements = True # write XM measurements
+ config.xprof = False # Profile using xprof.
+ config.checkpoint = True # Do checkpointing.
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+ config.debug_train = False # Debug mode during training.
+ config.debug_eval = False # Debug mode during eval.
+ config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+ # Evaluation.
+ config.eval_mode = True
+ config.eval_configs = ml_collections.ConfigDict()
+ config.eval_configs.mode = 'standard'
+ config.model.input_shape = target_size
+ config.eval_configs.store_logits = False
+
+ # Eval parameters for robustness
+ config.eval_label_shift = True
+ config.eval_covariate_shift = True
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'nmlogit'
+ config.eval_robustness_configs.num_top_k = 1
+
+ # Load checkpoint
+ config.checkpoint_configs = ml_collections.ConfigDict()
+ config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN
+ config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH
+ config.checkpoint_configs.classifier = 'token'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
+ if runlocal:
+ config.count_flops = False
+ config.dataset_configs.train_target_size = (128, 128)
+ config.model.input_shape = config.dataset_configs.train_target_size
+ config.batch_size = 8
+ config.num_training_epochs = 5
+ config.warmup_steps = 0
+ config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]'
+ config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]'
+ config.num_train_examples = TRAIN_SAMPLES
+
+ return config
+
+
+def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task):
+ """Defines checkpoints for sweep."""
+ overwrites = []
+ if resnet_size is not None:
+ raise NotImplementedError('')
+ else:
+ overwrites.append(
+ hyper.sweep('config.model.patches', [{
+ 'size': (stride, stride)
+ }]))
+
+ if vit_size == 'B':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768]))
+ elif vit_size == 'L':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024]))
+ else:
+ raise NotImplementedError('')
+
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_format',
+ [backbone_origin]))
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [
+ CHECKPOINT_PATHS[(backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task)]
+ ]))
+
+ return hyper.product(overwrites)
+
+
+def get_sweep(hyper):
+ return hyper.product([])
+
diff --git a/experimental/robust_segvit/configs/ade20k_ind/deterministic_seeds.py b/experimental/robust_segvit/configs/ade20k_ind/deterministic_seeds.py
new file mode 100644
index 000000000..d7e7468ef
--- /dev/null
+++ b/experimental/robust_segvit/configs/ade20k_ind/deterministic_seeds.py
@@ -0,0 +1,253 @@
+# coding=utf-8
+# Copyright 2022 The Uncertainty Baselines Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Train segmenter model on ade20k_ind.
+
+Compare performance across seeds.
+
+"""
+# pylint: enable=line-too-long
+
+import ml_collections
+import os
+import datetime
+
+_CITYSCAPES_FINE_TRAIN_SIZE = 2975
+_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
+
+_ADE20K_TRAIN_SIZE = 20210
+_PASCAL_VOC_TRAIN_SIZE = 10582
+_PASCAL_CONTEXT_TRAIN_SIZE = 4998
+
+TRAIN_SIZES = {
+ 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE,
+ 'ade20k': _ADE20K_TRAIN_SIZE,
+ 'ade20k_ind': _ADE20K_TRAIN_SIZE,
+ 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE,
+ 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE
+}
+
+# Model specs.
+LOAD_PRETRAINED_BACKBONE = True
+BACKBONE_ORIGIN = 'vision_transformer'
+VIT_SIZE = 'L'
+STRIDE = 16
+RESNET_SIZE = None
+CLASSIFIER = 'token'
+target_size = (640, 640)
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
+
+
+# Upstream
+MODEL_PATHS = {
+
+ # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384
+ ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
+ 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
+ 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+}
+
+
+MODEL_PATH = MODEL_PATHS[(BACKBONE_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE,
+ CLASSIFIER, UPSTREAM_TASK)]
+
+if VIT_SIZE == 'B':
+ mlp_dim = 3072
+ num_heads = 12
+ num_layers = 12
+ hidden_size = 768
+elif VIT_SIZE == 'L':
+ mlp_dim = 4096
+ num_heads = 16
+ num_layers = 24
+ hidden_size = 1024
+
+TRAIN_SAMPLES = 32
+
+
+def get_config(runlocal=''):
+ """Returns the configuration for ADE20k_ind segmentation."""
+
+ runlocal = bool(runlocal)
+
+ config = ml_collections.ConfigDict()
+ config.experiment_name = 'ade20k_ind_deterministic_seeds'
+
+ # Dataset.
+ config.dataset_name = 'robust_segvit_segmentation'
+ config.dataset_configs = ml_collections.ConfigDict()
+ config.dataset_configs.target_size = target_size
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
+
+ config.dataset_configs.train_split = 'train'
+ config.dataset_configs.name = 'ade20k_ind'
+ config.dataset_configs.dataset_name = '' # ood name flag to write in eval.
+
+ # Model.
+ config.model_name = 'segvit'
+ config.model = ml_collections.ConfigDict()
+
+ config.model.patches = ml_collections.ConfigDict()
+ config.model.patches.size = (STRIDE, STRIDE)
+
+ config.model.backbone = ml_collections.ConfigDict()
+ config.model.backbone.type = 'vit'
+ config.model.backbone.mlp_dim = mlp_dim
+ config.model.backbone.num_heads = num_heads
+ config.model.backbone.num_layers = num_layers
+ config.model.backbone.hidden_size = hidden_size
+ config.model.backbone.dropout_rate = 0.1
+ config.model.backbone.attention_dropout_rate = 0.0
+ config.model.backbone.classifier = CLASSIFIER
+
+ # Decoder
+ config.model.decoder = ml_collections.ConfigDict()
+ config.model.decoder.type = 'linear'
+
+ # Training.
+ config.trainer_name = 'segvit_trainer'
+ config.optimizer = 'adam'
+ config.optimizer_configs = ml_collections.ConfigDict()
+ config.l2_decay_factor = 0.0
+ config.max_grad_norm = 1.0
+ config.label_smoothing = None
+ config.num_training_epochs = ml_collections.FieldReference(100)
+ config.batch_size = 32
+ config.rng_seed = 0
+ config.focal_loss_gamma = 0.0
+
+ # Learning rate.
+ config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name)
+ config.steps_per_epoch = config.get_ref(
+ 'num_train_examples') // config.get_ref('batch_size')
+ # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine.
+ config.lr_configs = ml_collections.ConfigDict()
+ config.lr_configs.learning_rate_schedule = 'compound'
+ config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+ config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch')
+ config.lr_configs.steps_per_cycle = config.get_ref(
+ 'num_training_epochs') * config.get_ref('steps_per_epoch')
+ config.lr_configs.base_learning_rate = 3e-5
+
+ # model and data dtype
+ config.model_dtype_str = 'float32'
+ config.data_dtype_str = 'float32'
+
+ # load pretrained backbone
+ config.load_pretrained_backbone = LOAD_PRETRAINED_BACKBONE
+ config.pretrained_backbone_configs = ml_collections.ConfigDict()
+ config.pretrained_backbone_configs.checkpoint_format = BACKBONE_ORIGIN
+ config.pretrained_backbone_configs.checkpoint_path = MODEL_PATH
+ config.pretrained_backbone_configs.token_init = True
+ config.pretrained_backbone_configs.classifier = 'token'
+ config.pretrained_backbone_configs.backbone_type = 'vit'
+
+ # Logging.
+ config.write_summary = True
+ config.write_xm_measurements = True # write XM measurements
+ config.xprof = False # Profile using xprof.
+ config.checkpoint = True # Do checkpointing.
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+ config.debug_train = False # Debug mode during training.
+ config.debug_eval = False # Debug mode during eval.
+ config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+ # Evaluation.
+ config.eval_configs = ml_collections.ConfigDict()
+ config.eval_configs.mode = 'standard'
+ config.eval_mode = False
+ config.eval_covariate_shift = False
+ config.eval_label_shift = False
+ config.model.input_shape = target_size
+
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'mlogit'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
+ if runlocal:
+ config.count_flops = False
+ config.dataset_configs.train_target_size = (128, 128)
+ config.model.input_shape = config.dataset_configs.train_target_size
+ config.batch_size = 8
+ config.num_training_epochs = 5
+ config.warmup_steps = 0
+ config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]'
+ config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]'
+ config.num_train_examples = TRAIN_SAMPLES
+
+ return config
+
+
+def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task):
+ """Defines checkpoints for sweep."""
+ overwrites = []
+ if resnet_size is not None:
+ raise NotImplementedError('')
+ else:
+ overwrites.append(
+ hyper.sweep('config.model.patches', [{
+ 'size': (stride, stride)
+ }]))
+
+ if vit_size == 'B':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768]))
+ elif vit_size == 'L':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024]))
+ else:
+ raise NotImplementedError('')
+
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_format',
+ [backbone_origin]))
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [
+ MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task)]
+ ]))
+
+ return hyper.product(overwrites)
+
+
+def get_sweep(hyper):
+ """Defines the hyper-parameters sweeps for doing grid search."""
+
+ seeds = hyper.sweep('config.rng_seed', range(0, 5))
+
+
+ return hyper.product([seeds])
diff --git a/experimental/robust_segvit/configs/ade20k_ind/gp.py b/experimental/robust_segvit/configs/ade20k_ind/gp.py
index fe3c3186d..bd022cbc6 100644
--- a/experimental/robust_segvit/configs/ade20k_ind/gp.py
+++ b/experimental/robust_segvit/configs/ade20k_ind/gp.py
@@ -22,6 +22,8 @@
# pylint: enable=line-too-long
import ml_collections
+import os
+import datetime
_CITYSCAPES_FINE_TRAIN_SIZE = 2975
_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
@@ -40,21 +42,23 @@
# Model specs.
LOAD_PRETRAINED_BACKBONE = True
-BACKBONE_ORIGIN = 'big_vision'
+BACKBONE_ORIGIN = 'vision_transformer'
VIT_SIZE = 'L'
STRIDE = 16
RESNET_SIZE = None
CLASSIFIER = 'token'
target_size = (640, 640)
-UPSTREAM_TASK = 'i21k+imagenet2012'
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
# Upstream
MODEL_PATHS = {
# Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384
- ('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'):
- 'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
+ 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
+ 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
}
@@ -187,9 +191,25 @@ def get_config(runlocal=''):
config.eval_label_shift = False
config.model.input_shape = target_size
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'msp'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
if runlocal:
config.count_flops = False
config.dataset_configs.train_target_size = (128, 128)
+ config.model.input_shape = config.dataset_configs.train_target_size
config.batch_size = 8
config.num_training_epochs = 5
config.warmup_steps = 0
diff --git a/experimental/robust_segvit/configs/ade20k_ind/gp_eval.py b/experimental/robust_segvit/configs/ade20k_ind/gp_eval.py
new file mode 100644
index 000000000..f505258d7
--- /dev/null
+++ b/experimental/robust_segvit/configs/ade20k_ind/gp_eval.py
@@ -0,0 +1,255 @@
+# coding=utf-8
+# Copyright 2022 The Uncertainty Baselines Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Eval segmenter model on ade20k_ind.
+
+"""
+# pylint: enable=line-too-long
+
+import ml_collections
+import os
+import datetime
+
+_CITYSCAPES_FINE_TRAIN_SIZE = 2975
+_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
+
+_ADE20K_TRAIN_SIZE = 20210
+_PASCAL_VOC_TRAIN_SIZE = 10582
+_PASCAL_CONTEXT_TRAIN_SIZE = 4998
+
+TRAIN_SIZES = {
+ 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE,
+ 'ade20k': _ADE20K_TRAIN_SIZE,
+ 'ade20k_ind': _ADE20K_TRAIN_SIZE,
+ 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE,
+ 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE
+}
+
+# Model specs (together with EXPERIMENTID they form the CHECKPOINT_PATHS key).
+VIT_SIZE = 'L'
+STRIDE = 16
+RESNET_SIZE = None
+CLASSIFIER = 'token'
+target_size = (640, 640)
+
+CHECKPOINT_ORIGIN = 'ub'
+EXPERIMENTID='45350699-1'
+# Checkpoints to load for evaluation, keyed by model spec and experiment id.
+CHECKPOINT_PATHS = {
+ ('ub', 'L', 16, None, 'token', '45350699-1'):
+ 'gs://ub-checkpoints/45350699-ade20k_ind_segmenter_gp/1',
+}
+
+
+CHECKPOINT_PATH = CHECKPOINT_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE,
+ RESNET_SIZE, CLASSIFIER, EXPERIMENTID)]
+
+
+if VIT_SIZE == 'B':
+ mlp_dim = 3072
+ num_heads = 12
+ num_layers = 12
+ hidden_size = 768
+elif VIT_SIZE == 'L':
+ mlp_dim = 4096
+ num_heads = 16
+ num_layers = 24
+ hidden_size = 1024
+
+TRAIN_SAMPLES = 32 # Examples taken from each split when runlocal is set.
+
+
+def get_config(runlocal=''):
+ """Returns the config for evaluating the GP segmenter on ADE20k_ind."""
+
+ runlocal = bool(runlocal)
+
+ config = ml_collections.ConfigDict()
+ config.experiment_name = 'ade20k_ind_segmenter_gp_eval'
+
+ # Dataset.
+ config.dataset_name = 'robust_segvit_segmentation'
+ config.dataset_configs = ml_collections.ConfigDict()
+ config.dataset_configs.target_size = target_size
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
+
+ config.dataset_configs.train_split = 'train'
+ config.dataset_configs.name = 'ade20k_ind'
+ config.dataset_configs.dataset_name = '' # ood name flag to write in eval.
+
+ # Model.
+ config.model_name = 'segvit'
+ config.model = ml_collections.ConfigDict()
+
+ config.model.patches = ml_collections.ConfigDict()
+ config.model.patches.size = (STRIDE, STRIDE)
+
+ config.model.backbone = ml_collections.ConfigDict()
+ config.model.backbone.type = 'vit'
+ config.model.backbone.mlp_dim = mlp_dim
+ config.model.backbone.num_heads = num_heads
+ config.model.backbone.num_layers = num_layers
+ config.model.backbone.hidden_size = hidden_size
+ config.model.backbone.dropout_rate = 0.0
+ config.model.backbone.attention_dropout_rate = 0.0
+ config.model.backbone.classifier = CLASSIFIER
+
+ # Decoder
+ config.model.decoder = ml_collections.ConfigDict()
+ config.model.decoder.type = 'gp'
+
+ # GP layer params
+ config.model.decoder.gp_layer = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.covmat_kwargs = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.covmat_kwargs.ridge_penalty = 1.
+ # NOTE(review): the earlier note said to disable momentum to get the exact
+ # covariance update, but momentum is enabled (0.99) below -- confirm intent.
+ config.model.decoder.gp_layer.covmat_kwargs.momentum = 0.99
+ config.model.decoder.mean_field_factor = 1.
+ # Additional params
+ config.model.decoder.gp_layer.normalize_input = True
+ config.model.decoder.gp_layer.hidden_kwargs = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.hidden_kwargs.feature_scale = 1.
+
+ # Training.
+ config.trainer_name = 'segvit_trainer'
+ config.optimizer = 'adam'
+ config.optimizer_configs = ml_collections.ConfigDict()
+ config.l2_decay_factor = 0.0
+ config.max_grad_norm = 1.0
+ config.label_smoothing = None
+ config.num_training_epochs = ml_collections.FieldReference(100)
+ config.batch_size = 32
+ config.rng_seed = 0
+ config.focal_loss_gamma = 0.0
+
+ # Learning rate.
+ config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name)
+ config.steps_per_epoch = config.get_ref(
+ 'num_train_examples') // config.get_ref('batch_size')
+ # Setting 'steps_per_cycle' to total_steps basically means non-cycling cosine.
+ config.lr_configs = ml_collections.ConfigDict()
+ config.lr_configs.learning_rate_schedule = 'compound'
+ config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+ config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch')
+ config.lr_configs.steps_per_cycle = config.get_ref(
+ 'num_training_epochs') * config.get_ref('steps_per_epoch')
+ config.lr_configs.base_learning_rate = 3e-5
+
+ # model and data dtype
+ config.model_dtype_str = 'float32'
+ config.data_dtype_str = 'float32'
+
+ # Logging.
+ config.write_summary = True
+ config.write_xm_measurements = True # write XM measurements
+ config.xprof = False # Profile using xprof.
+ config.checkpoint = True # Do checkpointing.
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+ config.debug_train = False # Debug mode during training.
+ config.debug_eval = False # Debug mode during eval.
+ config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+ # Evaluation.
+ config.eval_mode = True
+ config.eval_configs = ml_collections.ConfigDict()
+ config.eval_configs.mode = 'standard'
+ config.model.input_shape = target_size
+ config.eval_configs.store_logits = False
+
+ # Eval parameters for robustness
+ config.eval_label_shift = True
+ config.eval_covariate_shift = True
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'nmlogit'
+ config.eval_robustness_configs.num_top_k = 1
+
+ # Load checkpoint
+ config.checkpoint_configs = ml_collections.ConfigDict()
+ config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN
+ config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH
+ config.checkpoint_configs.classifier = 'token'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Placeholder; overwritten by the next statement.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
+ if runlocal:
+ config.count_flops = False
+ config.dataset_configs.train_target_size = (128, 128)
+ config.model.input_shape = config.dataset_configs.train_target_size
+ config.batch_size = 8
+ config.num_training_epochs = 5
+ config.warmup_steps = 0 # NOTE(review): top-level key; lr_configs.warmup_steps is not changed -- confirm.
+ config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]'
+ config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]'
+ config.num_train_examples = TRAIN_SAMPLES
+
+ return config
+
+
+def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task):
+ """Defines the sweep overrides (patch size, ViT dims, checkpoint) for one model."""
+ overwrites = []
+ if resnet_size is not None:
+ raise NotImplementedError(f'resnet_size={resnet_size!r} is not supported.')
+ else:
+ overwrites.append(
+ hyper.sweep('config.model.patches', [{
+ 'size': (stride, stride)
+ }]))
+
+ if vit_size == 'B':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768]))
+ elif vit_size == 'L':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024]))
+ else:
+ raise NotImplementedError(f'Unsupported vit_size: {vit_size!r}.')
+
+ overwrites.append( # NOTE(review): get_config() here never defines pretrained_backbone_configs -- confirm.
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_format',
+ [backbone_origin]))
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [
+ CHECKPOINT_PATHS[(backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task)]
+ ]))
+
+ return hyper.product(overwrites)
+
+
+def get_sweep(hyper):
+ return hyper.product([]) # No hyper-parameters are swept in this config.
+
diff --git a/experimental/robust_segvit/configs/ade20k_ind/gp_seeds.py b/experimental/robust_segvit/configs/ade20k_ind/gp_seeds.py
new file mode 100644
index 000000000..5c69a5728
--- /dev/null
+++ b/experimental/robust_segvit/configs/ade20k_ind/gp_seeds.py
@@ -0,0 +1,267 @@
+# coding=utf-8
+# Copyright 2022 The Uncertainty Baselines Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Train segmenter model on ade20k_ind.
+
+Compare performance across seeds.
+
+"""
+# pylint: enable=line-too-long
+
+import ml_collections
+import os
+import datetime
+
+_CITYSCAPES_FINE_TRAIN_SIZE = 2975
+_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
+
+_ADE20K_TRAIN_SIZE = 20210
+_PASCAL_VOC_TRAIN_SIZE = 10582
+_PASCAL_CONTEXT_TRAIN_SIZE = 4998
+
+TRAIN_SIZES = {
+ 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE,
+ 'ade20k': _ADE20K_TRAIN_SIZE,
+ 'ade20k_ind': _ADE20K_TRAIN_SIZE,
+ 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE,
+ 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE
+}
+
+# Model specs (together they form the MODEL_PATHS key).
+LOAD_PRETRAINED_BACKBONE = True
+BACKBONE_ORIGIN = 'vision_transformer'
+VIT_SIZE = 'L'
+STRIDE = 16
+RESNET_SIZE = None
+CLASSIFIER = 'token'
+target_size = (640, 640)
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
+
+
+# Upstream (pretrained backbone) checkpoints, keyed by model spec.
+MODEL_PATHS = {
+
+ # ImageNet-21k pretraining + fine-tuning on ImageNet-2012 (perf 0.85, adap_res 384).
+ ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
+ 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
+ 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+}
+
+
+MODEL_PATH = MODEL_PATHS[(BACKBONE_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE,
+ CLASSIFIER, UPSTREAM_TASK)]
+
+if VIT_SIZE == 'B':
+ mlp_dim = 3072
+ num_heads = 12
+ num_layers = 12
+ hidden_size = 768
+elif VIT_SIZE == 'L':
+ mlp_dim = 4096
+ num_heads = 16
+ num_layers = 24
+ hidden_size = 1024
+
+TRAIN_SAMPLES = 32 # Examples taken from each split when runlocal is set.
+
+
+def get_config(runlocal=''):
+ """Returns the config for training the GP segmenter on ADE20k_ind."""
+
+ runlocal = bool(runlocal)
+
+ config = ml_collections.ConfigDict()
+ config.experiment_name = 'ade20k_ind_segmenter_gp_seeds'
+
+ # Dataset.
+ config.dataset_name = 'robust_segvit_segmentation'
+ config.dataset_configs = ml_collections.ConfigDict()
+ config.dataset_configs.target_size = target_size
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
+
+ config.dataset_configs.train_split = 'train'
+ config.dataset_configs.name = 'ade20k_ind'
+ config.dataset_configs.dataset_name = '' # ood name flag to write in eval.
+
+ # Model.
+ config.model_name = 'segvit'
+ config.model = ml_collections.ConfigDict()
+
+ config.model.patches = ml_collections.ConfigDict()
+ config.model.patches.size = (STRIDE, STRIDE)
+
+ config.model.backbone = ml_collections.ConfigDict()
+ config.model.backbone.type = 'vit'
+ config.model.backbone.mlp_dim = mlp_dim
+ config.model.backbone.num_heads = num_heads
+ config.model.backbone.num_layers = num_layers
+ config.model.backbone.hidden_size = hidden_size
+ config.model.backbone.dropout_rate = 0.1
+ config.model.backbone.attention_dropout_rate = 0.0
+ config.model.backbone.classifier = CLASSIFIER
+
+ # Decoder
+ config.model.decoder = ml_collections.ConfigDict()
+ config.model.decoder.type = 'gp'
+
+ # GP layer params
+ config.model.decoder.gp_layer = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.covmat_kwargs = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.covmat_kwargs.ridge_penalty = 1.
+ # NOTE(review): the earlier note said to disable momentum to get the exact
+ # covariance update, but momentum is enabled (0.99) below -- confirm intent.
+ config.model.decoder.gp_layer.covmat_kwargs.momentum = 0.99
+ config.model.decoder.mean_field_factor = 9.
+ # Additional params
+ config.model.decoder.gp_layer.normalize_input = True
+ config.model.decoder.gp_layer.hidden_kwargs = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.hidden_kwargs.feature_scale = 1.
+
+ # Training.
+ config.trainer_name = 'segvit_trainer'
+ config.optimizer = 'adam'
+ config.optimizer_configs = ml_collections.ConfigDict()
+ config.l2_decay_factor = 0.0
+ config.max_grad_norm = 1.0
+ config.label_smoothing = None
+ config.num_training_epochs = ml_collections.FieldReference(100)
+ config.batch_size = 32
+ config.rng_seed = 0
+ config.focal_loss_gamma = 0.0
+
+ # Learning rate.
+ config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name)
+ config.steps_per_epoch = config.get_ref(
+ 'num_train_examples') // config.get_ref('batch_size')
+ # Setting 'steps_per_cycle' to total_steps basically means non-cycling cosine.
+ config.lr_configs = ml_collections.ConfigDict()
+ config.lr_configs.learning_rate_schedule = 'compound'
+ config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+ config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch')
+ config.lr_configs.steps_per_cycle = config.get_ref(
+ 'num_training_epochs') * config.get_ref('steps_per_epoch')
+ config.lr_configs.base_learning_rate = 3e-5
+
+ # model and data dtype
+ config.model_dtype_str = 'float32'
+ config.data_dtype_str = 'float32'
+
+ # load pretrained backbone
+ config.load_pretrained_backbone = LOAD_PRETRAINED_BACKBONE
+ config.pretrained_backbone_configs = ml_collections.ConfigDict()
+ config.pretrained_backbone_configs.checkpoint_format = BACKBONE_ORIGIN
+ config.pretrained_backbone_configs.checkpoint_path = MODEL_PATH
+ config.pretrained_backbone_configs.token_init = True
+ config.pretrained_backbone_configs.classifier = 'token'
+ config.pretrained_backbone_configs.backbone_type = 'vit'
+
+ # Logging.
+ config.write_summary = True
+ config.write_xm_measurements = True # write XM measurements
+ config.xprof = False # Profile using xprof.
+ config.checkpoint = True # Do checkpointing.
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+ config.debug_train = False # Debug mode during training.
+ config.debug_eval = False # Debug mode during eval.
+ config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+ # Evaluation.
+ config.eval_configs = ml_collections.ConfigDict()
+ config.eval_configs.mode = 'standard'
+ config.eval_mode = False
+ config.eval_covariate_shift = False
+ config.eval_label_shift = False
+ config.model.input_shape = target_size
+
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'mlogit'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Placeholder; overwritten by the next statement.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
+ if runlocal:
+ config.count_flops = False
+ config.dataset_configs.train_target_size = (128, 128)
+ config.model.input_shape = config.dataset_configs.train_target_size
+ config.batch_size = 8
+ config.num_training_epochs = 5
+ config.warmup_steps = 0 # NOTE(review): top-level key; lr_configs.warmup_steps is not changed -- confirm.
+ config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]'
+ config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]'
+ config.num_train_examples = TRAIN_SAMPLES
+
+ return config
+
+
+def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task):
+ """Defines the sweep overrides (patch size, ViT dims, checkpoint) for one model."""
+ overwrites = []
+ if resnet_size is not None:
+ raise NotImplementedError(f'resnet_size={resnet_size!r} is not supported.')
+ else:
+ overwrites.append(
+ hyper.sweep('config.model.patches', [{
+ 'size': (stride, stride)
+ }]))
+
+ if vit_size == 'B':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768]))
+ elif vit_size == 'L':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024]))
+ else:
+ raise NotImplementedError(f'Unsupported vit_size: {vit_size!r}.')
+
+ overwrites.append( # Point the run at the pretrained backbone for this spec.
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_format',
+ [backbone_origin]))
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [
+ MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task)]
+ ]))
+
+ return hyper.product(overwrites)
+
+
+def get_sweep(hyper):
+ """Defines the sweep over `config.rng_seed` (seeds 0 through 4)."""
+
+ seeds = hyper.sweep('config.rng_seed', range(0, 5))
+
+ return hyper.product([seeds])
+
+
diff --git a/experimental/robust_segvit/configs/ade20k_ind/het.py b/experimental/robust_segvit/configs/ade20k_ind/het.py
index 52ed0f2dd..b85e701ff 100644
--- a/experimental/robust_segvit/configs/ade20k_ind/het.py
+++ b/experimental/robust_segvit/configs/ade20k_ind/het.py
@@ -22,6 +22,8 @@
# pylint: enable=line-too-long
import ml_collections
+import os
+import datetime
_CITYSCAPES_FINE_TRAIN_SIZE = 2975
_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
@@ -40,21 +42,23 @@
# Model specs.
LOAD_PRETRAINED_BACKBONE = True
-BACKBONE_ORIGIN = 'big_vision'
+BACKBONE_ORIGIN = 'vision_transformer'
VIT_SIZE = 'L'
STRIDE = 16
RESNET_SIZE = None
CLASSIFIER = 'token'
target_size = (640, 640)
-UPSTREAM_TASK = 'i21k+imagenet2012'
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
# Upstream
MODEL_PATHS = {
# Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384
- ('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'):
- 'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
+ 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
+ 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
}
@@ -187,9 +191,25 @@ def get_config(runlocal=''):
config.eval_label_shift = False
config.model.input_shape = target_size
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'msp'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
if runlocal:
config.count_flops = False
config.dataset_configs.train_target_size = (128, 128)
+ config.model.input_shape = config.dataset_configs.train_target_size
config.batch_size = 8
config.num_training_epochs = 5
config.warmup_steps = 0
diff --git a/experimental/robust_segvit/configs/ade20k_ind/het_eval.py b/experimental/robust_segvit/configs/ade20k_ind/het_eval.py
index 66f06f8e4..a8385762d 100644
--- a/experimental/robust_segvit/configs/ade20k_ind/het_eval.py
+++ b/experimental/robust_segvit/configs/ade20k_ind/het_eval.py
@@ -22,6 +22,8 @@
# pylint: enable=line-too-long
import ml_collections
+import datetime
+import os
_CITYSCAPES_FINE_TRAIN_SIZE = 2975
_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
@@ -44,14 +46,14 @@
STRIDE = 16
RESNET_SIZE = None
CLASSIFIER = 'token'
-EXPERIMENTID = '43838062-14'
+EXPERIMENTID = '45350817-1'
target_size = (640, 640)
# Upstream
CHECKPOINT_PATHS = {
- ('ub', 'L', 16, None, 'token', '43838062-14'):
- 'gs://ub-ekb/checkpoints_to_upload/ade20k/43838062-14',
+ ('ub', 'L', 16, None, 'token', '45350817-1'):
+ 'gs://ub-checkpoints/45350817-ade20k_ind_segmenter_het_hyper/1',
}
@@ -116,12 +118,12 @@ def get_config(runlocal=''):
# Het layer params
# temp: wide sweep [0.15, 0.3, 0.5, 0.75, 1.0, 1.5, 2.0]
- config.model.decoder.temperature = 2.0
+ config.model.decoder.temperature = 1.0
# efficient low rank approx ~ FxK where K is the classes. False for K<20.
config.model.decoder.param_efficient = False
# F as a low rank approx of KxK matrix has num_factors:
# imagenet~15, jft~50, cifar~6, cityscapes~sweep(5-10).
- config.model.decoder.num_factors = 10
+ config.model.decoder.num_factors = 5
# mc_samples: use as much as can be afforded, ideally > 10.
config.model.decoder.mc_samples = 1000
config.model.decoder.return_locs = False
@@ -172,17 +174,34 @@ def get_config(runlocal=''):
config.eval_mode = True
config.eval_configs = ml_collections.ConfigDict()
config.eval_configs.mode = 'standard'
- config.eval_covariate_shift = True
- config.eval_label_shift = True
config.model.input_shape = target_size
+ config.eval_configs.store_logits = False
# Eval parameters for robustness
config.eval_label_shift = True
config.eval_covariate_shift = True
config.eval_robustness_configs = ml_collections.ConfigDict()
config.eval_robustness_configs.auc_online = True
- config.eval_robustness_configs.method_name = 'msp'
- config.eval_robustness_configs.num_top_k = 5
+ config.eval_robustness_configs.method_name = 'nmlogit'
+ config.eval_robustness_configs.num_top_k = 1
+
+ # Load checkpoint
+ config.checkpoint_configs = ml_collections.ConfigDict()
+ config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN
+ config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH
+ config.checkpoint_configs.classifier = 'token'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
if runlocal:
config.count_flops = False
@@ -193,12 +212,6 @@ def get_config(runlocal=''):
config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]'
config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]'
config.num_train_examples = TRAIN_SAMPLES
- else:
- # Load checkpoint
- config.checkpoint_configs = ml_collections.ConfigDict()
- config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN
- config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH
- config.checkpoint_configs.classifier = 'token'
return config
diff --git a/experimental/robust_segvit/configs/ade20k_ind/toy_model.py b/experimental/robust_segvit/configs/ade20k_ind/toy_model.py
index cfc274317..55dbdfdd7 100644
--- a/experimental/robust_segvit/configs/ade20k_ind/toy_model.py
+++ b/experimental/robust_segvit/configs/ade20k_ind/toy_model.py
@@ -20,6 +20,8 @@
# pylint: enable=line-too-long
import ml_collections
+import os
+import datetime
_CITYSCAPES_FINE_TRAIN_SIZE = 2975
_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
@@ -135,8 +137,23 @@ def get_config(runlocal=''):
config.eval_mode = False
config.eval_configs = ml_collections.ConfigDict()
config.eval_configs.mode = 'standard'
- config.eval_covariate_shift = False
- config.eval_label_shift = False
+ config.eval_covariate_shift = True
+ config.eval_label_shift = True
+
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'mlogit'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
if runlocal:
config.count_flops = False
diff --git a/experimental/robust_segvit/configs/cityscapes/be.py b/experimental/robust_segvit/configs/cityscapes/be.py
index 6f123e6db..2bf6586f6 100644
--- a/experimental/robust_segvit/configs/cityscapes/be.py
+++ b/experimental/robust_segvit/configs/cityscapes/be.py
@@ -22,27 +22,31 @@
# pylint: enable=line-too-long
import ml_collections
+import os
+import datetime
_CITYSCAPES_TRAIN_SIZE = 2975
_CITYSCAPES_TRAIN_SIZE_SPLIT = 146
# Model specs.
LOAD_PRETRAINED_BACKBONE = True
-BACKBONE_ORIGIN = 'big_vision'
+BACKBONE_ORIGIN = 'vision_transformer'
VIT_SIZE = 'L'
STRIDE = 16
RESNET_SIZE = None
CLASSIFIER = 'token'
target_size = (768, 768)
-UPSTREAM_TASK = 'i21k+imagenet2012'
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
# Upstream
MODEL_PATHS = {
# Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384
- ('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'):
- 'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
+ 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
+ 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
}
@@ -157,6 +161,23 @@ def get_config(runlocal=''):
config.eval_configs.mode = 'standard'
config.eval_covariate_shift = True
config.eval_label_shift = True
+ config.model.input_shape = target_size
+
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'msp'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
if runlocal:
config.count_flops = False
diff --git a/experimental/robust_segvit/configs/cityscapes/be_eval.py b/experimental/robust_segvit/configs/cityscapes/be_eval.py
index 38f30d6b0..72b6b63e5 100644
--- a/experimental/robust_segvit/configs/cityscapes/be_eval.py
+++ b/experimental/robust_segvit/configs/cityscapes/be_eval.py
@@ -14,14 +14,14 @@
# limitations under the License.
# pylint: disable=line-too-long
-r"""Train segmenter model on cityscapes dataset.
-
-Compare performance from deterministic upstream checkpoints.
+r"""Evaluate segmenter_be model on cityscapes dataset.
"""
# pylint: enable=line-too-long
import ml_collections
+import os
+import datetime
_CITYSCAPES_TRAIN_SIZE = 2975
_CITYSCAPES_TRAIN_SIZE_SPLIT = 146
@@ -33,12 +33,12 @@
RESNET_SIZE = None
CLASSIFIER = 'token'
target_size = (768, 768)
-EXPERIMENTID = '43838585-16'
+EXPERIMENTID = '45338505-1'
# Upstream
CHECKPOINT_PATHS = {
- ('ub', 'L', 16, None, 'token', '43838585-16'):
- 'gs://ub-ekb/checkpoints_to_upload/cityscapes/43838585-16',
+ ('ub', 'L', 16, None, 'token', '45338505-1'):
+ 'gs://ub-checkpoints/45338505-cityscapes_segmenter_be/1',
}
@@ -66,11 +66,15 @@ def get_config(runlocal=''):
config.experiment_name = 'cityscapes_segmenter_be_eval'
# Dataset.
- config.dataset_name = 'cityscapes'
+ config.dataset_name = 'robust_segvit_segmentation'
config.dataset_configs = ml_collections.ConfigDict()
config.dataset_configs.target_size = (1024, 2048)
config.dataset_configs.train_split = 'train'
- config.dataset_configs.dataset_name = '' # name of ood dataset to evaluate
+ config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
# Model.
config.model_name = 'segvit'
@@ -143,8 +147,7 @@ def get_config(runlocal=''):
config.eval_configs = ml_collections.ConfigDict()
config.eval_configs.mode = 'segmm'
config.eval_configs.window_stride = 512
- config.eval_covariate_shift = True
- config.eval_label_shift = True
+ config.eval_configs.store_logits = False
config.model.input_shape = target_size
# Eval parameters for robustness
@@ -152,8 +155,25 @@ def get_config(runlocal=''):
config.eval_covariate_shift = True
config.eval_robustness_configs = ml_collections.ConfigDict()
config.eval_robustness_configs.auc_online = True
- config.eval_robustness_configs.method_name = 'msp'
- config.eval_robustness_configs.num_top_k = 5
+ config.eval_robustness_configs.method_name = 'nmlogit'
+ config.eval_robustness_configs.num_top_k = 1
+
+ # Load checkpoint
+ config.checkpoint_configs = ml_collections.ConfigDict()
+ config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN
+ config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH
+ config.checkpoint_configs.classifier = 'token'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
if runlocal:
config.count_flops = False
@@ -164,12 +184,6 @@ def get_config(runlocal=''):
config.dataset_configs.train_split = 'train[:5%]'
config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE_SPLIT // config.get_ref(
'batch_size')
- else:
- # Load checkpoint
- config.checkpoint_configs = ml_collections.ConfigDict()
- config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN
- config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH
- config.checkpoint_configs.classifier = 'token'
return config
diff --git a/experimental/robust_segvit/configs/cityscapes/deterministic.py b/experimental/robust_segvit/configs/cityscapes/deterministic.py
index e3d249b94..56b9c290c 100644
--- a/experimental/robust_segvit/configs/cityscapes/deterministic.py
+++ b/experimental/robust_segvit/configs/cityscapes/deterministic.py
@@ -22,27 +22,31 @@
# pylint: enable=line-too-long
import ml_collections
+import os
+import datetime
_CITYSCAPES_TRAIN_SIZE = 2975
_CITYSCAPES_TRAIN_SIZE_SPLIT = 146
# Model specs.
LOAD_PRETRAINED_BACKBONE = True
-BACKBONE_ORIGIN = 'big_vision'
+BACKBONE_ORIGIN = 'vision_transformer'
VIT_SIZE = 'L'
STRIDE = 16
RESNET_SIZE = None
CLASSIFIER = 'token'
target_size = (768, 768)
-UPSTREAM_TASK = 'i21k+imagenet2012'
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
# Upstream
MODEL_PATHS = {
# Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384
- ('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'):
- 'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
+ 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
+ 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
}
@@ -70,11 +74,15 @@ def get_config(runlocal=''):
config.experiment_name = 'cityscapes_segmenter_pretrained'
# Dataset.
- config.dataset_name = 'cityscapes'
+ config.dataset_name = 'robust_segvit_segmentation'
config.dataset_configs = ml_collections.ConfigDict()
config.dataset_configs.target_size = target_size
config.dataset_configs.train_split = 'train'
- config.dataset_configs.dataset_name = '' # name of ood dataset to evaluate
+ config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
# Model.
config.model_name = 'segvit'
@@ -151,6 +159,22 @@ def get_config(runlocal=''):
config.eval_configs.mode = 'standard'
config.eval_covariate_shift = True
config.eval_label_shift = True
+ config.model.input_shape = target_size
+
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'msp'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
if runlocal:
config.count_flops = False
@@ -211,8 +235,8 @@ def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size,
def get_sweep(hyper):
"""Defines the hyper-parameters sweeps for doing grid search."""
checkpoints = hyper.chainit([
- checkpoint(hyper, 'big_vision', 'L', 16, None, 'token',
- 'i21k+imagenet2012'),
+ checkpoint(hyper, 'vision_transformer', 'L', 16, None, 'token',
+ 'augreg+i21k+imagenet2012'),
])
epochs = hyper.sweep('config.num_training_epochs', [50, 100, 300])
diff --git a/experimental/robust_segvit/configs/cityscapes/deterministic_eval.py b/experimental/robust_segvit/configs/cityscapes/deterministic_eval.py
new file mode 100644
index 000000000..363761709
--- /dev/null
+++ b/experimental/robust_segvit/configs/cityscapes/deterministic_eval.py
@@ -0,0 +1,232 @@
+# coding=utf-8
+# Copyright 2022 The Uncertainty Baselines Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Eval segmenter model trained on cityscapes dataset.
+
+"""
+# pylint: enable=line-too-long
+
+import ml_collections
+import os
+import datetime
+
+_CITYSCAPES_TRAIN_SIZE = 2975
+_CITYSCAPES_TRAIN_SIZE_SPLIT = 146
+
+# Model specs.
+VIT_SIZE = 'L'
+STRIDE = 16
+RESNET_SIZE = None
+CLASSIFIER = 'token'
+target_size = (768, 768)
+
+
+CHECKPOINT_ORIGIN = 'ub'
+EXPERIMENTID = '45337813-1'
+
+# Fine-tuned checkpoints, keyed by (origin, vit size, stride, resnet size, classifier, experiment id).
+CHECKPOINT_PATHS = {
+ ('ub', 'L', 16, None, 'token', '45337813-1'):
+ 'gs://ub-checkpoints/45337813-cityscapes_segmenter_pretrained/1',
+}
+
+
+CHECKPOINT_PATH = CHECKPOINT_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE,
+ RESNET_SIZE, CLASSIFIER, EXPERIMENTID)]
+
+if VIT_SIZE == 'B':
+ mlp_dim = 3072
+ num_heads = 12
+ num_layers = 12
+ hidden_size = 768
+elif VIT_SIZE == 'L':
+ mlp_dim = 4096
+ num_heads = 16
+ num_layers = 24
+ hidden_size = 1024
+
+
+def get_config(runlocal=''):
+ """Returns the configuration for evaluating a Cityscapes segmenter checkpoint."""
+
+ runlocal = bool(runlocal)  # Any non-empty string enables the local overrides at the end.
+
+ config = ml_collections.ConfigDict()
+ config.experiment_name = 'cityscapes_segmenter_eval'
+
+ # Dataset.
+ config.dataset_name = 'robust_segvit_segmentation'
+ config.dataset_configs = ml_collections.ConfigDict()
+ config.dataset_configs.target_size = (1024, 2048)  # Differs from model.input_shape (768, 768) set below.
+ config.dataset_configs.train_split = 'train'
+ config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')  # Reference to dataset_configs.target_size, not a copy.
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
+
+ # Model.
+ config.model_name = 'segvit'
+ config.model = ml_collections.ConfigDict()
+
+ config.model.patches = ml_collections.ConfigDict()
+ config.model.patches.size = (STRIDE, STRIDE)
+
+ config.model.backbone = ml_collections.ConfigDict()
+ config.model.backbone.type = 'vit'
+ config.model.backbone.mlp_dim = mlp_dim
+ config.model.backbone.num_heads = num_heads
+ config.model.backbone.num_layers = num_layers
+ config.model.backbone.hidden_size = hidden_size
+ config.model.backbone.dropout_rate = 0.0
+ config.model.backbone.attention_dropout_rate = 0.0
+ config.model.backbone.classifier = CLASSIFIER
+
+ # Decoder
+ config.model.decoder = ml_collections.ConfigDict()
+ config.model.decoder.type = 'linear'
+
+ # Training settings (note that eval_mode is set to True further down).
+ config.trainer_name = 'segvit_trainer'
+ config.optimizer = 'adam'
+ config.optimizer_configs = ml_collections.ConfigDict()
+ config.l2_decay_factor = 0.0
+ config.max_grad_norm = 1.0
+ config.label_smoothing = None
+ config.num_training_epochs = ml_collections.FieldReference(100)
+ config.batch_size = 64
+ config.rng_seed = 0
+ config.focal_loss_gamma = 0.0
+
+ # Learning rate.
+ config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE // config.get_ref(
+ 'batch_size')
+ # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine.
+ config.lr_configs = ml_collections.ConfigDict()
+ config.lr_configs.learning_rate_schedule = 'compound'
+ config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+ config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch')
+ config.lr_configs.steps_per_cycle = config.get_ref(
+ 'num_training_epochs') * config.get_ref('steps_per_epoch')
+ config.lr_configs.base_learning_rate = 1e-4
+
+ # model and data dtype
+ config.model_dtype_str = 'float32'
+ config.data_dtype_str = 'float32'
+
+ # Logging.
+ config.write_summary = True
+ config.write_xm_measurements = True # write XM measurements
+ config.xprof = False # Profile using xprof.
+ config.checkpoint = True # Do checkpointing.
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+ config.debug_train = False # Debug mode during training.
+ config.debug_eval = False # Debug mode during eval.
+ config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+ # Evaluation.
+ config.eval_mode = True
+ config.eval_configs = ml_collections.ConfigDict()
+ config.eval_configs.mode = 'segmm'
+ config.eval_configs.window_stride = 512
+ config.eval_configs.store_logits = False
+ config.model.input_shape = target_size  # Module-level target_size, i.e. (768, 768).
+
+ # Eval parameters for robustness
+ config.eval_label_shift = True
+ config.eval_covariate_shift = True
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'nmlogit'
+ config.eval_robustness_configs.num_top_k = 1
+
+ # Load the fine-tuned checkpoint selected by the module-level constants.
+ config.checkpoint_configs = ml_collections.ConfigDict()
+ config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN
+ config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH
+ config.checkpoint_configs.classifier = 'token'
+
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None  # Placeholder; replaced by the assignment just below.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))  # <this file's stem>_<local timestamp>
+ config.wandb_exp_group = None # Give experiment a group name.
+
+ if runlocal:
+ config.count_flops = False
+ config.target_size = (128, 128)  # NOTE(review): top-level field; dataset_configs.target_size is not changed - confirm intended.
+ config.batch_size = 8
+ config.num_training_epochs = 5
+ config.warmup_steps = 0  # NOTE(review): top-level field; lr_configs.warmup_steps is not changed - confirm intended.
+ config.dataset_configs.train_split = 'train[:5%]'
+ config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE_SPLIT // config.get_ref(
+ 'batch_size')
+
+ return config
+
+
+def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task):
+ """Builds sweep overrides for patch size, ViT backbone dims and checkpoint."""
+ overwrites = []
+ if resnet_size is not None:
+ raise NotImplementedError('')  # Only resnet_size=None is supported.
+ else:
+ overwrites.append(
+ hyper.sweep('config.model.patches', [{'size': (stride, stride)}]))
+
+ if vit_size == 'B':
+ overwrites.append(
+ hyper.sweep('config.model.backbone.mlp_dim', [3072]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.num_heads', [12]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.num_layers', [12]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.hidden_size', [768]))
+ elif vit_size == 'L':
+ overwrites.append(
+ hyper.sweep('config.model.backbone.mlp_dim', [4096]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.num_heads', [16]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.num_layers', [24]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.hidden_size', [1024]))
+ else:
+ raise NotImplementedError('')  # Only ViT sizes 'B' and 'L' are handled.
+
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_format',  # NOTE(review): get_config in this file defines checkpoint_configs, not pretrained_backbone_configs - confirm key.
+ [backbone_origin]))
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [
+ CHECKPOINT_PATHS[(backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task)]  # In this file the last key element is the experiment id.
+ ]))
+
+ return hyper.product(overwrites)
+
+
+def get_sweep(hyper):
+ return hyper.product([])  # No sweep entries are defined for this evaluation config.
diff --git a/experimental/robust_segvit/configs/cityscapes/deterministic_seeds.py b/experimental/robust_segvit/configs/cityscapes/deterministic_seeds.py
new file mode 100644
index 000000000..c3c0de27b
--- /dev/null
+++ b/experimental/robust_segvit/configs/cityscapes/deterministic_seeds.py
@@ -0,0 +1,242 @@
+# coding=utf-8
+# Copyright 2022 The Uncertainty Baselines Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Train segmenter model on cityscapes dataset.
+
+Compare performance across seeds.
+
+"""
+# pylint: enable=line-too-long
+
+import ml_collections
+import os
+import datetime
+
+_CITYSCAPES_TRAIN_SIZE = 2975
+_CITYSCAPES_TRAIN_SIZE_SPLIT = 146
+
+# Model specs.
+LOAD_PRETRAINED_BACKBONE = True
+BACKBONE_ORIGIN = 'vision_transformer'
+VIT_SIZE = 'L'
+STRIDE = 16
+RESNET_SIZE = None
+CLASSIFIER = 'token'
+target_size = (768, 768)
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
+
+
+# Upstream
+MODEL_PATHS = {
+
+ # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384
+ ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
+ 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
+ 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+}
+
+
+MODEL_PATH = MODEL_PATHS[(BACKBONE_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE,
+ CLASSIFIER, UPSTREAM_TASK)]
+
+if VIT_SIZE == 'B':
+ mlp_dim = 3072
+ num_heads = 12
+ num_layers = 12
+ hidden_size = 768
+elif VIT_SIZE == 'L':
+ mlp_dim = 4096
+ num_heads = 16
+ num_layers = 24
+ hidden_size = 1024
+
+
+def get_config(runlocal=''):
+ """Returns the training configuration for Cityscapes segmentation (seed study)."""
+
+ runlocal = bool(runlocal)  # Any non-empty string enables the local overrides at the end.
+
+ config = ml_collections.ConfigDict()
+ config.experiment_name = 'cityscapes_segmenter_seeds'
+
+ # Dataset.
+ config.dataset_name = 'robust_segvit_segmentation'
+ config.dataset_configs = ml_collections.ConfigDict()
+ config.dataset_configs.target_size = target_size
+ config.dataset_configs.train_split = 'train'
+ config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')  # Reference to dataset_configs.target_size, not a copy.
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
+
+ # Model.
+ config.model_name = 'segvit'
+ config.model = ml_collections.ConfigDict()
+
+ config.model.patches = ml_collections.ConfigDict()
+ config.model.patches.size = (STRIDE, STRIDE)
+
+ config.model.backbone = ml_collections.ConfigDict()
+ config.model.backbone.type = 'vit'
+ config.model.backbone.mlp_dim = mlp_dim
+ config.model.backbone.num_heads = num_heads
+ config.model.backbone.num_layers = num_layers
+ config.model.backbone.hidden_size = hidden_size
+ config.model.backbone.dropout_rate = 0.1
+ config.model.backbone.attention_dropout_rate = 0.0
+ config.model.backbone.classifier = CLASSIFIER
+
+ # Decoder
+ config.model.decoder = ml_collections.ConfigDict()
+ config.model.decoder.type = 'linear'
+
+ # Training.
+ config.trainer_name = 'segvit_trainer'
+ config.optimizer = 'adam'
+ config.optimizer_configs = ml_collections.ConfigDict()
+ config.l2_decay_factor = 0.0
+ config.max_grad_norm = 1.0
+ config.label_smoothing = None
+ config.num_training_epochs = ml_collections.FieldReference(100)
+ config.batch_size = 64
+ config.rng_seed = 0  # Default seed; get_sweep in this file sweeps config.rng_seed.
+ config.focal_loss_gamma = 0.0
+
+ # Learning rate.
+ config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE // config.get_ref(
+ 'batch_size')
+ # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine.
+ config.lr_configs = ml_collections.ConfigDict()
+ config.lr_configs.learning_rate_schedule = 'compound'
+ config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+ config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch')
+ config.lr_configs.steps_per_cycle = config.get_ref(
+ 'num_training_epochs') * config.get_ref('steps_per_epoch')
+ config.lr_configs.base_learning_rate = 1e-4
+
+ # model and data dtype
+ config.model_dtype_str = 'float32'
+ config.data_dtype_str = 'float32'
+
+ # Load the pretrained backbone selected by the module-level constants.
+ config.load_pretrained_backbone = LOAD_PRETRAINED_BACKBONE
+ config.pretrained_backbone_configs = ml_collections.ConfigDict()
+ config.pretrained_backbone_configs.checkpoint_format = BACKBONE_ORIGIN
+ config.pretrained_backbone_configs.checkpoint_path = MODEL_PATH
+ config.pretrained_backbone_configs.token_init = True
+ config.pretrained_backbone_configs.classifier = 'token'
+ config.pretrained_backbone_configs.backbone_type = 'vit'
+
+ # Logging.
+ config.write_summary = True
+ config.write_xm_measurements = True # write XM measurements
+ config.xprof = False # Profile using xprof.
+ config.checkpoint = True # Do checkpointing.
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+ config.debug_train = False # Debug mode during training.
+ config.debug_eval = False # Debug mode during eval.
+ config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+ # Evaluation.
+ config.eval_mode = False
+ config.eval_configs = ml_collections.ConfigDict()
+ config.eval_configs.mode = 'standard'
+ config.eval_covariate_shift = True
+ config.eval_label_shift = True
+ config.model.input_shape = target_size
+
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'mlogit'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None  # Placeholder; replaced by the assignment just below.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))  # <this file's stem>_<local timestamp>
+ config.wandb_exp_group = None # Give experiment a group name.
+
+ if runlocal:
+ config.count_flops = False
+ config.target_size = (128, 128)  # NOTE(review): top-level field; dataset_configs.target_size is not changed - confirm intended.
+ config.batch_size = 8
+ config.num_training_epochs = 5
+ config.warmup_steps = 0  # NOTE(review): top-level field; lr_configs.warmup_steps is not changed - confirm intended.
+ config.dataset_configs.train_split = 'train[:5%]'
+ config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE_SPLIT // config.get_ref(
+ 'batch_size')
+
+ return config
+
+
+def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task):
+ """Builds sweep overrides for patch size, ViT backbone dims and checkpoint."""
+ overwrites = []
+ if resnet_size is not None:
+ raise NotImplementedError('')  # Only resnet_size=None is supported.
+ else:
+ overwrites.append(
+ hyper.sweep('config.model.patches', [{'size': (stride, stride)}]))
+
+ if vit_size == 'B':
+ overwrites.append(
+ hyper.sweep('config.model.backbone.mlp_dim', [3072]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.num_heads', [12]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.num_layers', [12]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.hidden_size', [768]))
+ elif vit_size == 'L':
+ overwrites.append(
+ hyper.sweep('config.model.backbone.mlp_dim', [4096]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.num_heads', [16]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.num_layers', [24]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.hidden_size', [1024]))
+ else:
+ raise NotImplementedError('')  # Only ViT sizes 'B' and 'L' are handled.
+
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_format',
+ [backbone_origin]))
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [
+ MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task)]  # Raises KeyError for combinations missing from MODEL_PATHS.
+ ]))
+
+ return hyper.product(overwrites)
+
+
+def get_sweep(hyper):
+ """Sweeps config.rng_seed over seeds 0-4 to compare runs across seeds."""
+
+ seeds = hyper.sweep('config.rng_seed', range(0, 5))
+
+
+ return hyper.product([seeds])
+
diff --git a/experimental/robust_segvit/configs/cityscapes/gp.py b/experimental/robust_segvit/configs/cityscapes/gp.py
index 29205f705..e9ac3c654 100644
--- a/experimental/robust_segvit/configs/cityscapes/gp.py
+++ b/experimental/robust_segvit/configs/cityscapes/gp.py
@@ -22,27 +22,31 @@
# pylint: enable=line-too-long
import ml_collections
+import os
+import datetime
_CITYSCAPES_TRAIN_SIZE = 2975
_CITYSCAPES_TRAIN_SIZE_SPLIT = 146
# Model specs.
LOAD_PRETRAINED_BACKBONE = True
-BACKBONE_ORIGIN = 'big_vision'
+BACKBONE_ORIGIN = 'vision_transformer'
VIT_SIZE = 'L'
STRIDE = 16
RESNET_SIZE = None
CLASSIFIER = 'token'
target_size = (768, 768)
-UPSTREAM_TASK = 'i21k+imagenet2012'
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
# Upstream
MODEL_PATHS = {
# Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384
- ('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'):
- 'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
+ 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
+ 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
}
@@ -70,11 +74,15 @@ def get_config(runlocal=''):
config.experiment_name = 'cityscapes_segmenter_gp_hyper'
# Dataset.
- config.dataset_name = 'cityscapes'
+ config.dataset_name = 'robust_segvit_segmentation'
config.dataset_configs = ml_collections.ConfigDict()
config.dataset_configs.target_size = target_size
config.dataset_configs.train_split = 'train'
- config.dataset_configs.dataset_name = '' # name of ood dataset to evaluate
+ config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
# Model.
config.model_name = 'segvit'
@@ -164,6 +172,22 @@ def get_config(runlocal=''):
config.eval_configs.mode = 'standard'
config.eval_covariate_shift = True
config.eval_label_shift = True
+ config.model.input_shape = target_size
+
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'msp'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
if runlocal:
config.count_flops = False
diff --git a/experimental/robust_segvit/configs/cityscapes/gp_eval.py b/experimental/robust_segvit/configs/cityscapes/gp_eval.py
new file mode 100644
index 000000000..d5836360f
--- /dev/null
+++ b/experimental/robust_segvit/configs/cityscapes/gp_eval.py
@@ -0,0 +1,245 @@
+# coding=utf-8
+# Copyright 2022 The Uncertainty Baselines Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Evaluate segmenter_gp model on cityscapes dataset.
+
+"""
+# pylint: enable=line-too-long
+
+import ml_collections
+import os
+import datetime
+
+_CITYSCAPES_TRAIN_SIZE = 2975
+_CITYSCAPES_TRAIN_SIZE_SPLIT = 146
+
+# Model specs.
+VIT_SIZE = 'L'
+STRIDE = 16
+RESNET_SIZE = None
+CLASSIFIER = 'token'
+target_size = (768, 768)
+
+CHECKPOINT_ORIGIN = 'ub'
+EXPERIMENTID = '45338722-1'
+
+# Fine-tuned checkpoints, keyed by (origin, vit size, stride, resnet size, classifier, experiment id).
+CHECKPOINT_PATHS = {
+ ('ub', 'L', 16, None, 'token', '45338722-1'):
+ 'gs://ub-checkpoints/45338722-cityscapes_segmenter_gp_hyper/1',
+}
+
+
+CHECKPOINT_PATH = CHECKPOINT_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE,
+ RESNET_SIZE, CLASSIFIER, EXPERIMENTID)]
+
+
+if VIT_SIZE == 'B':
+ mlp_dim = 3072
+ num_heads = 12
+ num_layers = 12
+ hidden_size = 768
+elif VIT_SIZE == 'L':
+ mlp_dim = 4096
+ num_heads = 16
+ num_layers = 24
+ hidden_size = 1024
+
+
+def get_config(runlocal=''):
+ """Returns the configuration for evaluating a Cityscapes GP-decoder segmenter checkpoint."""
+
+ runlocal = bool(runlocal)  # Any non-empty string enables the local overrides at the end.
+
+ config = ml_collections.ConfigDict()
+ config.experiment_name = 'cityscapes_segmenter_gp_eval'
+
+ # Dataset.
+ config.dataset_name = 'robust_segvit_segmentation'
+ config.dataset_configs = ml_collections.ConfigDict()
+ config.dataset_configs.target_size = (1024, 2048)  # Differs from model.input_shape (768, 768) set below.
+ config.dataset_configs.train_split = 'train'
+ config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')  # Reference to dataset_configs.target_size, not a copy.
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
+
+ # Model.
+ config.model_name = 'segvit'
+ config.model = ml_collections.ConfigDict()
+
+ config.model.patches = ml_collections.ConfigDict()
+ config.model.patches.size = (STRIDE, STRIDE)
+
+ config.model.backbone = ml_collections.ConfigDict()
+ config.model.backbone.type = 'vit'
+ config.model.backbone.mlp_dim = mlp_dim
+ config.model.backbone.num_heads = num_heads
+ config.model.backbone.num_layers = num_layers
+ config.model.backbone.hidden_size = hidden_size
+ config.model.backbone.dropout_rate = 0.0
+ config.model.backbone.attention_dropout_rate = 0.0
+ config.model.backbone.classifier = CLASSIFIER
+
+ # Decoder
+ config.model.decoder = ml_collections.ConfigDict()
+ config.model.decoder.type = 'gp'
+
+ # GP layer params
+ config.model.decoder.gp_layer = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.covmat_kwargs = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.covmat_kwargs.ridge_penalty = 1.
+ # NOTE(review): the original comment said momentum is disabled to allow exact
+ # covariance updates, but the value set below is 0.99 - confirm intended.
+ config.model.decoder.gp_layer.covmat_kwargs.momentum = 0.99
+ config.model.decoder.mean_field_factor = 1.
+ # Additional params
+ config.model.decoder.gp_layer.normalize_input = True
+ config.model.decoder.gp_layer.hidden_kwargs = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.hidden_kwargs.feature_scale = 1.
+
+ # Training settings (note that eval_mode is set to True further down).
+ config.trainer_name = 'segvit_trainer'
+ config.optimizer = 'adam'
+ config.optimizer_configs = ml_collections.ConfigDict()
+ config.l2_decay_factor = 0.0
+ config.max_grad_norm = 1.0
+ config.label_smoothing = None
+ config.num_training_epochs = ml_collections.FieldReference(100)
+ config.batch_size = 64
+ config.rng_seed = 0
+ config.focal_loss_gamma = 0.0
+
+ # Learning rate.
+ config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE // config.get_ref(
+ 'batch_size')
+ # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine.
+ config.lr_configs = ml_collections.ConfigDict()
+ config.lr_configs.learning_rate_schedule = 'compound'
+ config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+ config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch')
+ config.lr_configs.steps_per_cycle = config.get_ref(
+ 'num_training_epochs') * config.get_ref('steps_per_epoch')
+ config.lr_configs.base_learning_rate = 1e-4
+
+ # model and data dtype
+ config.model_dtype_str = 'float32'
+ config.data_dtype_str = 'float32'
+
+ # Logging.
+ config.write_summary = True
+ config.write_xm_measurements = True # write XM measurements
+ config.xprof = False # Profile using xprof.
+ config.checkpoint = True # Do checkpointing.
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+ config.debug_train = False # Debug mode during training.
+ config.debug_eval = False # Debug mode during eval.
+ config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+ # Evaluation.
+ config.eval_mode = True
+ config.eval_configs = ml_collections.ConfigDict()
+ config.eval_configs.mode = 'segmm'
+ config.eval_configs.window_stride = 512
+ config.eval_configs.store_logits = False
+ config.model.input_shape = target_size  # Module-level target_size, i.e. (768, 768).
+
+ # Eval parameters for robustness
+ config.eval_label_shift = True
+ config.eval_covariate_shift = True
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'nmlogit'
+ config.eval_robustness_configs.num_top_k = 1
+
+ # Load the fine-tuned checkpoint selected by the module-level constants.
+ config.checkpoint_configs = ml_collections.ConfigDict()
+ config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN
+ config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH
+ config.checkpoint_configs.classifier = 'token'
+
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None  # Placeholder; replaced by the assignment just below.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))  # <this file's stem>_<local timestamp>
+ config.wandb_exp_group = None # Give experiment a group name.
+
+ if runlocal:
+ config.count_flops = False
+ config.target_size = (128, 128)  # NOTE(review): top-level field; dataset_configs.target_size is not changed - confirm intended.
+ config.batch_size = 8
+ config.num_training_epochs = 5
+ config.warmup_steps = 0  # NOTE(review): top-level field; lr_configs.warmup_steps is not changed - confirm intended.
+ config.dataset_configs.train_split = 'train[:5%]'
+ config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE_SPLIT // config.get_ref(
+ 'batch_size')
+
+ return config
+
+
+def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task):
+ """Builds sweep overrides for patch size, ViT backbone dims and checkpoint."""
+ overwrites = []
+ if resnet_size is not None:
+ raise NotImplementedError('')  # Only resnet_size=None is supported.
+ else:
+ overwrites.append(
+ hyper.sweep('config.model.patches', [{'size': (stride, stride)}]))
+
+ if vit_size == 'B':
+ overwrites.append(
+ hyper.sweep('config.model.backbone.mlp_dim', [3072]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.num_heads', [12]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.num_layers', [12]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.hidden_size', [768]))
+ elif vit_size == 'L':
+ overwrites.append(
+ hyper.sweep('config.model.backbone.mlp_dim', [4096]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.num_heads', [16]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.num_layers', [24]))
+ overwrites.append(
+ hyper.sweep('config.model.backbone.hidden_size', [1024]))
+ else:
+ raise NotImplementedError('')  # Only ViT sizes 'B' and 'L' are handled.
+
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_format',  # NOTE(review): get_config in this file defines checkpoint_configs, not pretrained_backbone_configs - confirm key.
+ [backbone_origin]))
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [
+ CHECKPOINT_PATHS[(backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task)]  # In this file the last key element is the experiment id.
+ ]))
+
+ return hyper.product(overwrites)
+
+
+def get_sweep(hyper):
+ return hyper.product([])  # No sweep entries are defined for this evaluation config.
diff --git a/experimental/robust_segvit/configs/cityscapes/gp_seeds.py b/experimental/robust_segvit/configs/cityscapes/gp_seeds.py
new file mode 100644
index 000000000..f43d397b0
--- /dev/null
+++ b/experimental/robust_segvit/configs/cityscapes/gp_seeds.py
@@ -0,0 +1,260 @@
+# coding=utf-8
+# Copyright 2022 The Uncertainty Baselines Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Train segmenter model on cityscapes dataset.
+
+Compare performance across seeds.
+
+"""
+# pylint: enable=line-too-long
+
+import ml_collections
+import os
+import datetime
+
+# Example counts used to derive steps_per_epoch: the first for the default
+# run, the second for the reduced runlocal run.
+_CITYSCAPES_TRAIN_SIZE = 2975
+_CITYSCAPES_TRAIN_SIZE_SPLIT = 146
+
+# Model specs.
+LOAD_PRETRAINED_BACKBONE = True
+BACKBONE_ORIGIN = 'vision_transformer'
+VIT_SIZE = 'L'
+STRIDE = 16
+RESNET_SIZE = None
+CLASSIFIER = 'token'
+target_size = (768, 768)
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
+
+
+# Upstream checkpoints, keyed by (BACKBONE_ORIGIN, VIT_SIZE, STRIDE,
+# RESNET_SIZE, CLASSIFIER, UPSTREAM_TASK).
+MODEL_PATHS = {
+
+ # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384
+ ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
+ 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
+ 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+}
+
+
+# Path of the checkpoint selected by the model specs above.
+MODEL_PATH = MODEL_PATHS[(BACKBONE_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE,
+ CLASSIFIER, UPSTREAM_TASK)]
+
+# Backbone dimensions for the selected ViT size.
+# NOTE(review): a VIT_SIZE other than 'B' or 'L' leaves these names undefined.
+if VIT_SIZE == 'B':
+ mlp_dim = 3072
+ num_heads = 12
+ num_layers = 12
+ hidden_size = 768
+elif VIT_SIZE == 'L':
+ mlp_dim = 4096
+ num_heads = 16
+ num_layers = 24
+ hidden_size = 1024
+
+
+def get_config(runlocal=''):
+ """Returns the configuration for Cityscapes segmentation.
+
+ Args:
+   runlocal: Any truthy value (e.g. a non-empty string) applies the reduced
+     settings in the `if runlocal:` block at the end of this function.
+
+ Returns:
+   An ml_collections.ConfigDict for the 'cityscapes_segmenter_gp_seeds'
+   experiment.
+ """
+
+ runlocal = bool(runlocal)
+
+ config = ml_collections.ConfigDict()
+ config.experiment_name = 'cityscapes_segmenter_gp_seeds'
+
+ # Dataset.
+ config.dataset_name = 'robust_segvit_segmentation'
+ config.dataset_configs = ml_collections.ConfigDict()
+ config.dataset_configs.target_size = target_size
+ config.dataset_configs.train_split = 'train'
+ config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
+
+ # Model.
+ config.model_name = 'segvit'
+ config.model = ml_collections.ConfigDict()
+
+ config.model.patches = ml_collections.ConfigDict()
+ config.model.patches.size = (STRIDE, STRIDE)
+
+ config.model.backbone = ml_collections.ConfigDict()
+ config.model.backbone.type = 'vit'
+ config.model.backbone.mlp_dim = mlp_dim
+ config.model.backbone.num_heads = num_heads
+ config.model.backbone.num_layers = num_layers
+ config.model.backbone.hidden_size = hidden_size
+ config.model.backbone.dropout_rate = 0.1
+ config.model.backbone.attention_dropout_rate = 0.0
+ config.model.backbone.classifier = CLASSIFIER
+
+ # Decoder
+ config.model.decoder = ml_collections.ConfigDict()
+ config.model.decoder.type = 'gp'
+
+ # GP layer params
+ config.model.decoder.gp_layer = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.covmat_kwargs = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.covmat_kwargs.ridge_penalty = 1.
+ # NOTE(review): the previous comment here said momentum is disabled to get
+ # the exact covariance update, but the value set below is 0.99; confirm
+ # whether a momentum-based update (0.99) or the exact update is intended.
+ config.model.decoder.gp_layer.covmat_kwargs.momentum = 0.99
+ config.model.decoder.mean_field_factor = 3.
+ # Additional params
+ config.model.decoder.gp_layer.normalize_input = True
+ config.model.decoder.gp_layer.hidden_kwargs = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.hidden_kwargs.feature_scale = 1.
+
+ # Training.
+ config.trainer_name = 'segvit_trainer'
+ config.optimizer = 'adam'
+ config.optimizer_configs = ml_collections.ConfigDict()
+ config.l2_decay_factor = 0.0
+ config.max_grad_norm = 1.0
+ config.label_smoothing = None
+ config.num_training_epochs = ml_collections.FieldReference(100)
+ config.batch_size = 64
+ config.rng_seed = 0
+ config.focal_loss_gamma = 0.0
+
+ # Learning rate.
+ # steps_per_epoch and the values derived from it below are built from
+ # get_ref() references rather than plain numbers.
+ config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE // config.get_ref(
+ 'batch_size')
+ # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine.
+ config.lr_configs = ml_collections.ConfigDict()
+ config.lr_configs.learning_rate_schedule = 'compound'
+ config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+ config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch')
+ config.lr_configs.steps_per_cycle = config.get_ref(
+ 'num_training_epochs') * config.get_ref('steps_per_epoch')
+ config.lr_configs.base_learning_rate = 1e-4
+
+ # model and data dtype
+ config.model_dtype_str = 'float32'
+ config.data_dtype_str = 'float32'
+
+ # load pretrained backbone
+ config.load_pretrained_backbone = LOAD_PRETRAINED_BACKBONE
+ config.pretrained_backbone_configs = ml_collections.ConfigDict()
+ config.pretrained_backbone_configs.checkpoint_format = BACKBONE_ORIGIN
+ config.pretrained_backbone_configs.checkpoint_path = MODEL_PATH
+ config.pretrained_backbone_configs.token_init = True
+ config.pretrained_backbone_configs.classifier = 'token'
+ config.pretrained_backbone_configs.backbone_type = 'vit'
+
+ # Logging.
+ config.write_summary = True
+ config.write_xm_measurements = True # write XM measurements
+ config.xprof = False # Profile using xprof.
+ config.checkpoint = True # Do checkpointing.
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+ config.debug_train = False # Debug mode during training.
+ config.debug_eval = False # Debug mode during eval.
+ config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+ # Evaluation.
+ config.eval_mode = False
+ config.eval_configs = ml_collections.ConfigDict()
+ config.eval_configs.mode = 'standard'
+ config.eval_covariate_shift = True
+ config.eval_label_shift = True
+ config.model.input_shape = target_size
+
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'mlogit'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ # Immediately replaces the None above with '<config file stem>_<timestamp>'.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
+ # Reduced settings for a quick local run.
+ # NOTE(review): target_size and warmup_steps below are written as new
+ # top-level keys; the values configured above live in
+ # config.dataset_configs.target_size and config.lr_configs.warmup_steps.
+ # Confirm which keys the trainer reads.
+ if runlocal:
+ config.count_flops = False
+ config.target_size = (128, 128)
+ config.batch_size = 8
+ config.num_training_epochs = 5
+ config.warmup_steps = 0
+ config.dataset_configs.train_split = 'train[:5%]'
+ config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE_SPLIT // config.get_ref(
+ 'batch_size')
+
+ return config
+
+
+def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size,
+               classifier, upstream_task):
+  """Builds the sweep overrides that select one pretrained backbone.
+
+  Args:
+    hyper: Sweep helper; only its `sweep` and `product` methods are used.
+    backbone_origin: Checkpoint format; also the first element of the
+      MODEL_PATHS key.
+    vit_size: ViT variant, 'B' or 'L'; selects the backbone dimensions.
+    stride: Patch side length; swept as a square (stride, stride) patch.
+    resnet_size: Must be None; any other value is rejected.
+    classifier: Fifth element of the MODEL_PATHS key.
+    upstream_task: Sixth element of the MODEL_PATHS key.
+
+  Returns:
+    The `hyper.product` of the single-valued sweeps built here.
+
+  Raises:
+    NotImplementedError: If `resnet_size` is not None, or `vit_size` is
+      neither 'B' nor 'L'.
+    KeyError: If the argument tuple is not a key of MODEL_PATHS.
+  """
+  if resnet_size is not None:
+    raise NotImplementedError(
+        f'resnet_size={resnet_size!r} is not supported; pass None.')
+
+  overwrites = [
+      hyper.sweep('config.model.patches', [{'size': (stride, stride)}]),
+  ]
+
+  if vit_size == 'B':
+    overwrites.append(
+        hyper.sweep('config.model.backbone.mlp_dim', [3072]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.num_heads', [12]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.num_layers', [12]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.hidden_size', [768]))
+  elif vit_size == 'L':
+    overwrites.append(
+        hyper.sweep('config.model.backbone.mlp_dim', [4096]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.num_heads', [16]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.num_layers', [24]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.hidden_size', [1024]))
+  else:
+    raise NotImplementedError(
+        f"vit_size={vit_size!r} is not supported; expected 'B' or 'L'.")
+
+  overwrites.append(
+      hyper.sweep('config.pretrained_backbone_configs.checkpoint_format',
+                  [backbone_origin]))
+  overwrites.append(
+      hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [
+          MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size,
+                       classifier, upstream_task)]
+      ]))
+
+  return hyper.product(overwrites)
+
+
+def get_sweep(hyper):
+  """Builds the grid search over the GP decoder hyper-parameters.
+
+  The grid is the product of `normalize_input` in (True, False),
+  `mean_field_factor` over range(1, 10) and `feature_scale` in (1.0, 2.0).
+
+  Args:
+    hyper: Sweep helper providing `sweep`, `discrete` and `product`.
+
+  Returns:
+    The `hyper.product` of the three sweeps.
+  """
+  normalize_input = hyper.sweep(
+      'config.model.decoder.gp_layer.normalize_input', [True, False])
+  mean_field_factor = hyper.sweep(
+      'config.model.decoder.mean_field_factor', hyper.discrete(range(1, 10)))
+  feature_scale = hyper.sweep(
+      'config.model.decoder.gp_layer.hidden_kwargs.feature_scale', [1.0, 2.0])
+
+  return hyper.product([normalize_input, mean_field_factor, feature_scale])
diff --git a/experimental/robust_segvit/configs/cityscapes/het.py b/experimental/robust_segvit/configs/cityscapes/het.py
index 45e263afe..7d6b8f552 100644
--- a/experimental/robust_segvit/configs/cityscapes/het.py
+++ b/experimental/robust_segvit/configs/cityscapes/het.py
@@ -22,27 +22,31 @@
# pylint: enable=line-too-long
import ml_collections
+import os
+import datetime
_CITYSCAPES_TRAIN_SIZE = 2975
_CITYSCAPES_TRAIN_SIZE_SPLIT = 146
# Model specs.
LOAD_PRETRAINED_BACKBONE = True
-BACKBONE_ORIGIN = 'big_vision'
+BACKBONE_ORIGIN = 'vision_transformer'
VIT_SIZE = 'L'
STRIDE = 16
RESNET_SIZE = None
CLASSIFIER = 'token'
target_size = (768, 768)
-UPSTREAM_TASK = 'i21k+imagenet2012'
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
# Upstream
MODEL_PATHS = {
# Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384
- ('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'):
- 'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
+ 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
+ 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
}
@@ -70,11 +74,15 @@ def get_config(runlocal=''):
config.experiment_name = 'cityscapes_segmenter_het_base'
# Dataset.
- config.dataset_name = 'cityscapes'
+ config.dataset_name = 'robust_segvit_segmentation'
config.dataset_configs = ml_collections.ConfigDict()
config.dataset_configs.target_size = target_size
config.dataset_configs.train_split = 'train'
- config.dataset_configs.dataset_name = '' # name of ood dataset to evaluate
+ config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
# Model.
config.model_name = 'segvit'
@@ -165,6 +173,22 @@ def get_config(runlocal=''):
config.eval_configs.mode = 'standard'
config.eval_covariate_shift = True
config.eval_label_shift = True
+ config.model.input_shape = target_size
+
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'msp'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
if runlocal:
config.count_flops = False
diff --git a/experimental/robust_segvit/configs/cityscapes/het_eval.py b/experimental/robust_segvit/configs/cityscapes/het_eval.py
new file mode 100644
index 000000000..4038e5688
--- /dev/null
+++ b/experimental/robust_segvit/configs/cityscapes/het_eval.py
@@ -0,0 +1,249 @@
+# coding=utf-8
+# Copyright 2022 The Uncertainty Baselines Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Evaluate segmenter model with the 'het' decoder on cityscapes dataset.
+
+Runs in eval mode, loading the checkpoint selected from CHECKPOINT_PATHS.
+
+"""
+# pylint: enable=line-too-long
+
+import ml_collections
+import os
+import datetime
+
+# Example counts used to derive steps_per_epoch: the first for the default
+# run, the second for the reduced runlocal run.
+_CITYSCAPES_TRAIN_SIZE = 2975
+_CITYSCAPES_TRAIN_SIZE_SPLIT = 146
+
+# Model specs.
+VIT_SIZE = 'L'
+STRIDE = 16
+RESNET_SIZE = None
+CLASSIFIER = 'token'
+target_size = (768, 768)
+
+
+# Which entry of CHECKPOINT_PATHS to load.
+CHECKPOINT_ORIGIN = 'ub'
+EXPERIMENTID = '45338794-1'
+
+# Checkpoints to evaluate, keyed by (CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE,
+# RESNET_SIZE, CLASSIFIER, EXPERIMENTID).
+CHECKPOINT_PATHS = {
+ ('ub', 'L', 16, None, 'token', '45338794-1'):
+ 'gs://ub-checkpoints/45338794-cityscapes_segmenter_het_base/1',
+}
+
+
+# Path of the checkpoint selected by the constants above.
+CHECKPOINT_PATH = CHECKPOINT_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE,
+ RESNET_SIZE, CLASSIFIER, EXPERIMENTID)]
+
+# Backbone dimensions for the selected ViT size.
+# NOTE(review): a VIT_SIZE other than 'B' or 'L' leaves these names undefined.
+if VIT_SIZE == 'B':
+ mlp_dim = 3072
+ num_heads = 12
+ num_layers = 12
+ hidden_size = 768
+elif VIT_SIZE == 'L':
+ mlp_dim = 4096
+ num_heads = 16
+ num_layers = 24
+ hidden_size = 1024
+
+
+def get_config(runlocal=''):
+ """Returns the configuration for Cityscapes segmentation.
+
+ Args:
+   runlocal: Any truthy value (e.g. a non-empty string) applies the reduced
+     settings in the `if runlocal:` block at the end of this function.
+
+ Returns:
+   An ml_collections.ConfigDict for the 'cityscapes_segmenter_het_eval'
+   experiment, with eval_mode enabled.
+ """
+
+ runlocal = bool(runlocal)
+
+ config = ml_collections.ConfigDict()
+ config.experiment_name = 'cityscapes_segmenter_het_eval'
+
+ # Dataset.
+ config.dataset_name = 'robust_segvit_segmentation'
+ config.dataset_configs = ml_collections.ConfigDict()
+ config.dataset_configs.target_size = (1024, 2048)
+ config.dataset_configs.train_split = 'train'
+ config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
+
+ # Model.
+ config.model_name = 'segvit'
+ config.model = ml_collections.ConfigDict()
+
+ config.model.patches = ml_collections.ConfigDict()
+ config.model.patches.size = (STRIDE, STRIDE)
+
+ config.model.backbone = ml_collections.ConfigDict()
+ config.model.backbone.type = 'vit'
+ config.model.backbone.mlp_dim = mlp_dim
+ config.model.backbone.num_heads = num_heads
+ config.model.backbone.num_layers = num_layers
+ config.model.backbone.hidden_size = hidden_size
+ config.model.backbone.dropout_rate = 0.0
+ config.model.backbone.attention_dropout_rate = 0.0
+ config.model.backbone.classifier = CLASSIFIER
+
+ # Decoder
+ config.model.decoder = ml_collections.ConfigDict()
+ config.model.decoder.type = 'het'
+
+ # Het layer params
+ # temp: wide sweep [0.15, 0.3, 0.5, 0.75, 1.0, 1.5, 2.0]
+ config.model.decoder.temperature = 1.0
+ # efficient low rank approx ~ FxK where K is the classes. False for K<20.
+ config.model.decoder.param_efficient = False
+ # F as a low rank approx of KxK matrix has num_factors:
+ # imagenet~15, jft~50, cifar~6, cityscapes~sweep(5-10).
+ config.model.decoder.num_factors = 5
+ # mc_samples: use as much as can be afforded, ideally > 10.
+ config.model.decoder.mc_samples = 1000
+ config.model.decoder.return_locs = False
+ # turn on to run an approx on KHW x KHW instead of KxK.
+ config.model.decoder.share_samples_across_batch = False
+
+ # Training.
+ config.trainer_name = 'segvit_trainer'
+ config.optimizer = 'adam'
+ config.optimizer_configs = ml_collections.ConfigDict()
+ config.l2_decay_factor = 0.0
+ config.max_grad_norm = 1.0
+ config.label_smoothing = None
+ config.num_training_epochs = ml_collections.FieldReference(100)
+ config.batch_size = 64
+ config.rng_seed = 0
+ config.focal_loss_gamma = 0.0
+
+ # Learning rate.
+ # steps_per_epoch and the values derived from it below are built from
+ # get_ref() references rather than plain numbers.
+ config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE // config.get_ref(
+ 'batch_size')
+ # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine.
+ config.lr_configs = ml_collections.ConfigDict()
+ config.lr_configs.learning_rate_schedule = 'compound'
+ config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+ config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch')
+ config.lr_configs.steps_per_cycle = config.get_ref(
+ 'num_training_epochs') * config.get_ref('steps_per_epoch')
+ config.lr_configs.base_learning_rate = 1e-4
+
+ # model and data dtype
+ config.model_dtype_str = 'float32'
+ config.data_dtype_str = 'float32'
+
+ # Logging.
+ config.write_summary = True
+ config.write_xm_measurements = True # write XM measurements
+ config.xprof = False # Profile using xprof.
+ config.checkpoint = True # Do checkpointing.
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+ config.debug_train = False # Debug mode during training.
+ config.debug_eval = False # Debug mode during eval.
+ config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+ # Evaluation.
+ config.eval_mode = True
+ config.eval_configs = ml_collections.ConfigDict()
+ config.eval_configs.mode = 'segmm'
+ config.eval_configs.window_stride = 512
+ config.eval_configs.store_logits = False
+ # input_shape is the module-level target_size (768, 768), not
+ # dataset_configs.target_size (1024, 2048).
+ config.model.input_shape = target_size
+
+ # Eval parameters for robustness
+ config.eval_label_shift = True
+ config.eval_covariate_shift = True
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'nmlogit'
+ config.eval_robustness_configs.num_top_k = 1
+
+ # Load checkpoint
+ config.checkpoint_configs = ml_collections.ConfigDict()
+ config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN
+ config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH
+ config.checkpoint_configs.classifier = 'token'
+
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ # Immediately replaces the None above with '<config file stem>_<timestamp>'.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
+ # Reduced settings for a quick local run.
+ # NOTE(review): target_size and warmup_steps below are written as new
+ # top-level keys; the values configured above live in
+ # config.dataset_configs.target_size and config.lr_configs.warmup_steps.
+ # Confirm which keys the trainer reads.
+ if runlocal:
+ config.count_flops = False
+ config.target_size = (128, 128)
+ config.batch_size = 8
+ config.num_training_epochs = 5
+ config.warmup_steps = 0
+ config.dataset_configs.train_split = 'train[:5%]'
+ config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE_SPLIT // config.get_ref(
+ 'batch_size')
+
+ return config
+
+
+def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size,
+               classifier, upstream_task):
+  """Builds the sweep overrides that select one checkpoint to evaluate.
+
+  The checkpoint overrides are written to `config.checkpoint_configs`, the
+  section that `get_config` in this file populates.
+
+  Args:
+    hyper: Sweep helper; only its `sweep` and `product` methods are used.
+    backbone_origin: Checkpoint format; also the first element of the
+      CHECKPOINT_PATHS key.
+    vit_size: ViT variant, 'B' or 'L'; selects the backbone dimensions.
+    stride: Patch side length; swept as a square (stride, stride) patch.
+    resnet_size: Must be None; any other value is rejected.
+    classifier: Fifth element of the CHECKPOINT_PATHS key.
+    upstream_task: Sixth element of the CHECKPOINT_PATHS key.
+
+  Returns:
+    The `hyper.product` of the single-valued sweeps built here.
+
+  Raises:
+    NotImplementedError: If `resnet_size` is not None, or `vit_size` is
+      neither 'B' nor 'L'.
+    KeyError: If the argument tuple is not a key of CHECKPOINT_PATHS.
+  """
+  if resnet_size is not None:
+    raise NotImplementedError(
+        f'resnet_size={resnet_size!r} is not supported; pass None.')
+
+  overwrites = [
+      hyper.sweep('config.model.patches', [{'size': (stride, stride)}]),
+  ]
+
+  if vit_size == 'B':
+    overwrites.append(
+        hyper.sweep('config.model.backbone.mlp_dim', [3072]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.num_heads', [12]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.num_layers', [12]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.hidden_size', [768]))
+  elif vit_size == 'L':
+    overwrites.append(
+        hyper.sweep('config.model.backbone.mlp_dim', [4096]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.num_heads', [16]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.num_layers', [24]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.hidden_size', [1024]))
+  else:
+    raise NotImplementedError(
+        f"vit_size={vit_size!r} is not supported; expected 'B' or 'L'.")
+
+  # This config has no pretrained_backbone_configs section; the checkpoint
+  # to load is described by config.checkpoint_configs.
+  overwrites.append(
+      hyper.sweep('config.checkpoint_configs.checkpoint_format',
+                  [backbone_origin]))
+  overwrites.append(
+      hyper.sweep('config.checkpoint_configs.checkpoint_path', [
+          CHECKPOINT_PATHS[(backbone_origin, vit_size, stride, resnet_size,
+                            classifier, upstream_task)]
+      ]))
+
+  return hyper.product(overwrites)
+
+
+def get_sweep(hyper):
+  """Returns an empty sweep: `hyper.product` over no parameters."""
+  no_parameters = []
+  return hyper.product(no_parameters)
+
diff --git a/experimental/robust_segvit/configs/cityscapes/torch_eval.py b/experimental/robust_segvit/configs/cityscapes/torch_eval.py
new file mode 100644
index 000000000..532e9fe21
--- /dev/null
+++ b/experimental/robust_segvit/configs/cityscapes/torch_eval.py
@@ -0,0 +1,213 @@
+import ml_collections
+import os
+import datetime
+
+# Example counts used to derive steps_per_epoch: the first for the default
+# run, the second for the reduced runlocal run.
+_CITYSCAPES_TRAIN_SIZE = 2975
+_CITYSCAPES_TRAIN_SIZE_SPLIT = 146
+
+# Model specs.
+CHECKPOINT_ORIGIN = 'torch-segmm'
+VIT_SIZE = 'L'
+STRIDE = 16
+RESNET_SIZE = None
+CLASSIFIER = 'token'
+target_size = (768, 768)
+EXPERIMENTID = 'torch-segmm-1'
+
+# Checkpoints to evaluate, keyed by (CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE,
+# RESNET_SIZE, CLASSIFIER, EXPERIMENTID).
+CHECKPOINT_PATHS = {
+ ('torch-segmm', 'L', 16, None, 'token', 'torch-segmm-1'):
+ 'gs://ub-ekb/seg_l16_linear/checkpoint_model.npy',
+}
+
+
+# Path of the checkpoint selected by the constants above.
+CHECKPOINT_PATH = CHECKPOINT_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE,
+ RESNET_SIZE, CLASSIFIER, EXPERIMENTID)]
+
+# Backbone dimensions for the selected ViT size.
+# NOTE(review): a VIT_SIZE other than 'B' or 'L' leaves these names undefined.
+if VIT_SIZE == 'B':
+ mlp_dim = 3072
+ num_heads = 12
+ num_layers = 12
+ hidden_size = 768
+elif VIT_SIZE == 'L':
+ mlp_dim = 4096
+ num_heads = 16
+ num_layers = 24
+ hidden_size = 1024
+
+
+def get_config(runlocal=''):
+ """Returns the configuration for Cityscapes segmentation.
+
+ Args:
+   runlocal: Any truthy value (e.g. a non-empty string) applies the reduced
+     settings in the `if runlocal:` block at the end of this function.
+
+ Returns:
+   An ml_collections.ConfigDict for the 'cityscapes_segmenter_torch_eval'
+   experiment, with eval_mode enabled.
+ """
+
+ runlocal = bool(runlocal)
+
+ config = ml_collections.ConfigDict()
+ config.experiment_name = 'cityscapes_segmenter_torch_eval'
+
+ # Dataset.
+ config.dataset_name = 'robust_segvit_segmentation'
+ config.dataset_configs = ml_collections.ConfigDict()
+ config.dataset_configs.target_size = (1024, 2048)
+ config.dataset_configs.train_split = 'train'
+ config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
+
+ # Model.
+ config.model_name = 'segvit'
+ config.model = ml_collections.ConfigDict()
+
+ config.model.patches = ml_collections.ConfigDict()
+ config.model.patches.size = (STRIDE, STRIDE)
+
+ config.model.backbone = ml_collections.ConfigDict()
+ config.model.backbone.type = 'vit'
+ config.model.backbone.mlp_dim = mlp_dim
+ config.model.backbone.num_heads = num_heads
+ config.model.backbone.num_layers = num_layers
+ config.model.backbone.hidden_size = hidden_size
+ config.model.backbone.dropout_rate = 0.0
+ config.model.backbone.attention_dropout_rate = 0.0
+ config.model.backbone.classifier = CLASSIFIER
+
+ # Decoder
+ config.model.decoder = ml_collections.ConfigDict()
+ config.model.decoder.type = 'linear'
+
+ # Training.
+ config.trainer_name = 'segvit_trainer'
+ config.optimizer = 'adam'
+ config.optimizer_configs = ml_collections.ConfigDict()
+ config.l2_decay_factor = 0.0
+ config.max_grad_norm = 1.0
+ config.label_smoothing = None
+ config.num_training_epochs = ml_collections.FieldReference(100)
+ config.batch_size = 64
+ config.rng_seed = 0
+ config.focal_loss_gamma = 0.0
+
+ # Learning rate.
+ # steps_per_epoch and the values derived from it below are built from
+ # get_ref() references rather than plain numbers.
+ config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE // config.get_ref(
+ 'batch_size')
+ # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine.
+ config.lr_configs = ml_collections.ConfigDict()
+ config.lr_configs.learning_rate_schedule = 'compound'
+ config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+ config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch')
+ config.lr_configs.steps_per_cycle = config.get_ref(
+ 'num_training_epochs') * config.get_ref('steps_per_epoch')
+ config.lr_configs.base_learning_rate = 1e-4
+
+ # model and data dtype
+ config.model_dtype_str = 'float32'
+ config.data_dtype_str = 'float32'
+
+ # Logging.
+ config.write_summary = True
+ config.write_xm_measurements = True # write XM measurements
+ config.xprof = False # Profile using xprof.
+ config.checkpoint = True # Do checkpointing.
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+ config.debug_train = False # Debug mode during training.
+ config.debug_eval = False # Debug mode during eval.
+ config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+ # Evaluation.
+ config.eval_mode = True
+ config.eval_configs = ml_collections.ConfigDict()
+ config.eval_configs.mode = 'segmm'
+ config.eval_configs.window_stride = 512
+ # input_shape is the module-level target_size (768, 768), not
+ # dataset_configs.target_size (1024, 2048).
+ config.model.input_shape = target_size
+
+ # Eval parameters for robustness
+ config.eval_label_shift = True
+ config.eval_covariate_shift = True
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'nmlogit'
+ config.eval_robustness_configs.num_top_k = 1
+
+ # Load checkpoint
+ config.checkpoint_configs = ml_collections.ConfigDict()
+ config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN
+ config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH
+ config.checkpoint_configs.classifier = 'token'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ # Immediately replaces the None above with '<config file stem>_<timestamp>'.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
+ # Reduced settings for a quick local run.
+ # NOTE(review): target_size and warmup_steps below are written as new
+ # top-level keys; the values configured above live in
+ # config.dataset_configs.target_size and config.lr_configs.warmup_steps.
+ # Confirm which keys the trainer reads.
+ if runlocal:
+ config.count_flops = False
+ config.target_size = (128, 128)
+ config.batch_size = 8
+ config.num_training_epochs = 5
+ config.warmup_steps = 0
+ config.dataset_configs.train_split = 'train[:5%]'
+ config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE_SPLIT // config.get_ref(
+ 'batch_size')
+
+ return config
+
+
+def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size,
+               classifier, upstream_task):
+  """Builds the sweep overrides that select one checkpoint to evaluate.
+
+  Args:
+    hyper: Sweep helper; only its `sweep` and `product` methods are used.
+    backbone_origin: Checkpoint format; also the first element of the
+      CHECKPOINT_PATHS key.
+    vit_size: ViT variant, 'B' or 'L'; selects the backbone dimensions.
+    stride: Patch side length; swept as a square (stride, stride) patch.
+    resnet_size: Must be None; any other value is rejected.
+    classifier: Fifth element of the CHECKPOINT_PATHS key.
+    upstream_task: Sixth element of the CHECKPOINT_PATHS key.
+
+  Returns:
+    The `hyper.product` of the single-valued sweeps built here.
+
+  Raises:
+    NotImplementedError: If `resnet_size` is not None, or `vit_size` is
+      neither 'B' nor 'L'.
+    KeyError: If the argument tuple is not a key of CHECKPOINT_PATHS.
+  """
+  if resnet_size is not None:
+    raise NotImplementedError(
+        f'resnet_size={resnet_size!r} is not supported; pass None.')
+
+  overwrites = [
+      hyper.sweep('config.model.patches', [{'size': (stride, stride)}]),
+  ]
+
+  if vit_size == 'B':
+    overwrites.append(
+        hyper.sweep('config.model.backbone.mlp_dim', [3072]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.num_heads', [12]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.num_layers', [12]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.hidden_size', [768]))
+  elif vit_size == 'L':
+    overwrites.append(
+        hyper.sweep('config.model.backbone.mlp_dim', [4096]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.num_heads', [16]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.num_layers', [24]))
+    overwrites.append(
+        hyper.sweep('config.model.backbone.hidden_size', [1024]))
+  else:
+    raise NotImplementedError(
+        f"vit_size={vit_size!r} is not supported; expected 'B' or 'L'.")
+
+  overwrites.append(
+      hyper.sweep('config.checkpoint_configs.checkpoint_format',
+                  [backbone_origin]))
+  overwrites.append(
+      hyper.sweep('config.checkpoint_configs.checkpoint_path', [
+          CHECKPOINT_PATHS[(backbone_origin, vit_size, stride, resnet_size,
+                            classifier, upstream_task)]
+      ]))
+
+  return hyper.product(overwrites)
+
+
+def get_sweep(hyper):
+  """Defines the parameters used to compare multiple metrics during eval.
+
+  Args:
+    hyper: Sweep helper providing `chainit`, `product` and the methods used
+      by `checkpoint`.
+
+  Returns:
+    The `hyper.product` over the chained checkpoint overrides.
+  """
+  # The tuple passed to checkpoint() must be a key of CHECKPOINT_PATHS. The
+  # only registered key starts with 'torch-segmm', so an origin of 'ub' raised
+  # KeyError. The module constants are the same tuple already used to resolve
+  # CHECKPOINT_PATH.
+  checkpoints = hyper.chainit([
+      checkpoint(hyper, CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE,
+                 CLASSIFIER, EXPERIMENTID),
+  ])
+
+  return hyper.product([checkpoints])
diff --git a/experimental/robust_segvit/configs/cityscapes/toy_model.py b/experimental/robust_segvit/configs/cityscapes/toy_model.py
index 6a5e50e80..ba2147c5d 100644
--- a/experimental/robust_segvit/configs/cityscapes/toy_model.py
+++ b/experimental/robust_segvit/configs/cityscapes/toy_model.py
@@ -20,7 +20,8 @@
# pylint: enable=line-too-long
import ml_collections
-
+import os
+import datetime
batch_size = 128
_CITYSCAPES_TRAIN_SIZE_SPLIT = 146
@@ -43,11 +44,15 @@ def get_config(runlocal=''):
config.experiment_name = 'cityscapes_segmenter_toy_model'
# Dataset.
- config.dataset_name = 'cityscapes'
+ config.dataset_name = 'robust_segvit_segmentation'
config.dataset_configs = ml_collections.ConfigDict()
config.dataset_configs.target_size = target_size
config.dataset_configs.train_split = 'train[:5%]'
- config.dataset_configs.dataset_name = '' # name of ood dataset to evaluate
+ config.dataset_configs.name = 'cityscapes' # name of dataset to evaluate
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
# Model.
config.model_name = 'segvit'
@@ -118,6 +123,21 @@ def get_config(runlocal=''):
config.eval_covariate_shift = True
config.eval_label_shift = True
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'mlogit'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
if runlocal:
config.count_flops = False
diff --git a/experimental/robust_segvit/configs/cityscapes/toy_model_eval.py b/experimental/robust_segvit/configs/cityscapes/toy_model_eval.py
new file mode 100644
index 000000000..83965106e
--- /dev/null
+++ b/experimental/robust_segvit/configs/cityscapes/toy_model_eval.py
@@ -0,0 +1,165 @@
+# coding=utf-8
+# Copyright 2022 The Uncertainty Baselines Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Eval toy segmenter model on cityscapes.
+
+"""
+# pylint: enable=line-too-long
+
+import ml_collections
+import os
+import datetime
+
+batch_size = 8
+_CITYSCAPES_TRAIN_SIZE_SPLIT = 16
+
+# Model spec.
+STRIDE = 4
+mlp_dim = 2
+num_heads = 1
+num_layers = 1
+hidden_size = 1
+target_size = (128, 128)
+
+# Upstream
+CHECKPOINT_ORIGIN = 'ub'
+VIT_SIZE = 'debug'
+RESNET_SIZE = None
+CLASSIFIER = 'token'
+EXPERIMENTID = 'city_toy'
+
+CHECKPOINT_PATHS = {
+ ('ub', 'debug', 4, None, 'token', 'city_toy'):
+ 'ub-ekb/segmenter/cityscapes/toy_model/toy_model',
+}
+
+
+CHECKPOINT_PATH = CHECKPOINT_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE,
+ RESNET_SIZE, CLASSIFIER, EXPERIMENTID)]
+
+def get_config(runlocal=''):
+  """Returns the evaluation config for the toy Cityscapes segmenter."""
+
+  runlocal = bool(runlocal)
+
+  config = ml_collections.ConfigDict()
+  config.experiment_name = 'cityscapes_segmenter_toy_model'
+
+  # Dataset.
+  config.dataset_name = 'robust_segvit_segmentation'
+  config.dataset_configs = ml_collections.ConfigDict()
+  config.dataset_configs.target_size = target_size
+  config.dataset_configs.train_split = 'train[:16]'
+  config.dataset_configs.validation_split = 'validation[:16]'
+  config.dataset_configs.name = 'cityscapes'  # Name of dataset to evaluate.
+  config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+      'target_size')
+  config.dataset_configs.denoise = None
+  config.dataset_configs.use_timestep = 0
+
+  # Model.
+  config.model_name = 'segvit'
+  config.model = ml_collections.ConfigDict()
+
+  config.model.patches = ml_collections.ConfigDict()
+  config.model.patches.size = (STRIDE, STRIDE)
+
+  config.model.backbone = ml_collections.ConfigDict()
+  config.model.backbone.type = 'vit'
+  config.model.backbone.mlp_dim = mlp_dim
+  config.model.backbone.num_heads = num_heads
+  config.model.backbone.num_layers = num_layers
+  config.model.backbone.hidden_size = hidden_size
+  config.model.backbone.dropout_rate = 0.1
+  config.model.backbone.attention_dropout_rate = 0.0
+  config.model.backbone.classifier = 'gap'
+
+  # Decoder.
+  config.model.decoder = ml_collections.ConfigDict()
+  config.model.decoder.type = 'linear'
+
+  # Training.
+  config.trainer_name = 'segvit_trainer'
+  config.optimizer = 'adam'
+  config.optimizer_configs = ml_collections.ConfigDict()
+  config.l2_decay_factor = 0.0
+  config.max_grad_norm = 1.0
+  config.label_smoothing = None
+  config.num_training_epochs = ml_collections.FieldReference(2)
+  config.batch_size = batch_size
+  config.rng_seed = 0
+  config.focal_loss_gamma = 0.0
+
+  # Learning rate.
+  config.steps_per_epoch = _CITYSCAPES_TRAIN_SIZE_SPLIT // config.get_ref(
+      'batch_size')
+  # Setting 'steps_per_cycle' to the total step count gives non-cycling cosine.
+  config.lr_configs = ml_collections.ConfigDict()
+  config.lr_configs.learning_rate_schedule = 'compound'
+  config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+  config.lr_configs.warmup_steps = 0
+  config.lr_configs.steps_per_cycle = config.get_ref(
+      'num_training_epochs') * config.get_ref('steps_per_epoch')
+  config.lr_configs.base_learning_rate = 1e-4
+
+  # Model and data dtype.
+  config.model_dtype_str = 'float32'
+  config.data_dtype_str = 'float32'
+
+  # No pretrained-backbone initialization is set up in this config.
+
+  # Logging.
+  config.write_summary = True
+  config.write_xm_measurements = True  # Write XM measurements.
+  config.xprof = False  # Profile using xprof.
+  config.checkpoint = True  # Do checkpointing.
+  config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+  config.debug_train = False  # Debug mode during training.
+  config.debug_eval = False  # Debug mode during eval.
+  config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+  # Evaluation.
+  config.eval_mode = True
+  config.eval_configs = ml_collections.ConfigDict()
+  config.eval_configs.mode = 'standard'
+  config.eval_covariate_shift = True
+  config.eval_label_shift = True
+  config.eval_configs.store_logits = False
+
+  config.eval_robustness_configs = ml_collections.ConfigDict()
+  config.eval_robustness_configs.auc_online = True
+  config.eval_robustness_configs.method_name = 'mlogit'
+
+  # wandb.ai configurations.
+  config.use_wandb = False
+  config.wandb_dir = 'wandb'
+  config.wandb_project = 'rdl-debug'
+  config.wandb_entity = 'ekellbuch'
+  config.wandb_exp_name = None  # Placeholder; overwritten just below.
+  config.wandb_exp_name = (
+      os.path.splitext(os.path.basename(__file__))[0] + '_' +
+      datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+  config.wandb_exp_group = None  # Give experiment a group name.
+
+  if runlocal:
+    config.count_flops = False
+
+  return config
+
+
+def get_sweep(hyper):
+ return hyper.product([])
diff --git a/experimental/robust_segvit/configs/street_hazards/be.py b/experimental/robust_segvit/configs/street_hazards/be.py
new file mode 100644
index 000000000..2f6b9f7a9
--- /dev/null
+++ b/experimental/robust_segvit/configs/street_hazards/be.py
@@ -0,0 +1,265 @@
+# coding=utf-8
+# Copyright 2022 The Uncertainty Baselines Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Train segmenter model on street_hazards.
+
+Compare performance from deterministic upstream checkpoints.
+
+"""
+# pylint: enable=line-too-long
+
+import ml_collections
+import os
+import datetime
+
+_CITYSCAPES_FINE_TRAIN_SIZE = 2975
+_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
+
+_ADE20K_TRAIN_SIZE = 20210
+_PASCAL_VOC_TRAIN_SIZE = 10582
+_PASCAL_CONTEXT_TRAIN_SIZE = 4998
+_STREET_HAZARDS_TRAIN_SIZE = 5125
+
+TRAIN_SIZES = {
+ 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE,
+ 'ade20k': _ADE20K_TRAIN_SIZE,
+ 'ade20k_ind': _ADE20K_TRAIN_SIZE,
+ 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE,
+ 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE,
+ 'street_hazards': _STREET_HAZARDS_TRAIN_SIZE,
+
+}
+
+# Model specs.
+LOAD_PRETRAINED_BACKBONE = True
+BACKBONE_ORIGIN = 'vision_transformer'
+VIT_SIZE = 'L'
+STRIDE = 16
+RESNET_SIZE = None
+CLASSIFIER = 'token'
+target_size = (720, 720)
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
+
+
+# Upstream
+MODEL_PATHS = {
+
+ # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384
+ ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
+ 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
+ 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', # pylint: disable=g-long-lambda
+
+}
+
+
+MODEL_PATH = MODEL_PATHS[(BACKBONE_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE,
+ CLASSIFIER, UPSTREAM_TASK)]
+
+if VIT_SIZE == 'B':
+ mlp_dim = 3072
+ num_heads = 12
+ num_layers = 12
+ hidden_size = 768
+elif VIT_SIZE == 'L':
+ mlp_dim = 4096
+ num_heads = 16
+ num_layers = 24
+ hidden_size = 1024
+
+TRAIN_SAMPLES = 32
+
+
+def get_config(runlocal=''):
+ """Returns the configuration for street hazards segmentation."""
+
+ runlocal = bool(runlocal)
+
+ config = ml_collections.ConfigDict()
+ config.experiment_name = 'street_hazards_segmenter_be'
+
+ # Dataset.
+ config.dataset_name = 'robust_segvit_segmentation'
+ config.dataset_configs = ml_collections.ConfigDict()
+ config.dataset_configs.target_size = target_size
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
+
+ config.dataset_configs.train_split = 'train'
+ config.dataset_configs.name = 'street_hazards'
+ config.dataset_configs.dataset_name = '' # ood name flag to write in eval.
+
+ # Model.
+ config.model_name = 'segvit'
+ config.model = ml_collections.ConfigDict()
+
+ config.model.patches = ml_collections.ConfigDict()
+ config.model.patches.size = (STRIDE, STRIDE)
+
+ config.model.backbone = ml_collections.ConfigDict()
+ config.model.backbone.type = 'vit_be'
+ config.model.backbone.mlp_dim = mlp_dim
+ config.model.backbone.num_heads = num_heads
+ config.model.backbone.num_layers = num_layers
+ config.model.backbone.hidden_size = hidden_size
+ config.model.backbone.dropout_rate = 0.1
+ config.model.backbone.attention_dropout_rate = 0.0
+ config.model.backbone.classifier = CLASSIFIER
+
+ # Decoder
+ config.model.decoder = ml_collections.ConfigDict()
+ config.model.decoder.type = 'linear_be'
+
+ # BE variables
+ config.model.backbone.ens_size = 3
+ config.model.backbone.random_sign_init = -0.5
+ config.model.backbone.be_layers = (22, 23)
+ config.fast_weight_lr_multiplier = 1.0
+
+ # Training.
+ config.trainer_name = 'segvit_trainer'
+ config.optimizer = 'adam'
+ config.optimizer_configs = ml_collections.ConfigDict()
+ config.l2_decay_factor = 0.0
+ config.max_grad_norm = 1.0
+ config.label_smoothing = None
+ config.num_training_epochs = ml_collections.FieldReference(100)
+ config.batch_size = 32
+ config.rng_seed = 0
+ config.focal_loss_gamma = 0.0
+
+ # Learning rate.
+ config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name)
+ config.steps_per_epoch = config.get_ref(
+ 'num_train_examples') // config.get_ref('batch_size')
+ # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine.
+ config.lr_configs = ml_collections.ConfigDict()
+ config.lr_configs.learning_rate_schedule = 'compound'
+ config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+ config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch')
+ config.lr_configs.steps_per_cycle = config.get_ref(
+ 'num_training_epochs') * config.get_ref('steps_per_epoch')
+ config.lr_configs.base_learning_rate = 1e-4
+
+ # model and data dtype
+ config.model_dtype_str = 'float32'
+ config.data_dtype_str = 'float32'
+
+ # load pretrained backbone
+ config.load_pretrained_backbone = LOAD_PRETRAINED_BACKBONE
+ config.pretrained_backbone_configs = ml_collections.ConfigDict()
+ config.pretrained_backbone_configs.checkpoint_format = BACKBONE_ORIGIN
+ config.pretrained_backbone_configs.checkpoint_path = MODEL_PATH
+ config.pretrained_backbone_configs.token_init = True
+ config.pretrained_backbone_configs.classifier = 'token'
+ config.pretrained_backbone_configs.backbone_type = 'vit'
+
+ # Logging.
+ config.write_summary = True
+ config.write_xm_measurements = True # write XM measurements
+ config.xprof = False # Profile using xprof.
+ config.checkpoint = True # Do checkpointing.
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+ config.debug_train = False # Debug mode during training.
+ config.debug_eval = False # Debug mode during eval.
+ config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+ # Evaluation.
+ config.eval_configs = ml_collections.ConfigDict()
+ config.eval_configs.mode = 'standard'
+ config.eval_mode = False
+ config.eval_covariate_shift = True
+ config.eval_label_shift = True
+ config.model.input_shape = target_size
+
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'msp'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
+ if runlocal:
+ config.count_flops = False
+ config.dataset_configs.train_target_size = (128, 128)
+ config.model.input_shape = config.dataset_configs.train_target_size
+ config.batch_size = 8
+ config.num_training_epochs = 5
+ config.warmup_steps = 0
+ config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]'
+ config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]'
+ config.num_train_examples = TRAIN_SAMPLES
+
+ return config
+
+
+def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task):
+ """Defines checkpoints for sweep."""
+ overwrites = []
+ if resnet_size is not None:
+ raise NotImplementedError('')
+ else:
+ overwrites.append(
+ hyper.sweep('config.model.patches', [{
+ 'size': (stride, stride)
+ }]))
+
+ if vit_size == 'B':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768]))
+ elif vit_size == 'L':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024]))
+ else:
+ raise NotImplementedError('')
+
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_format',
+ [backbone_origin]))
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [
+ MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task)]
+ ]))
+
+ return hyper.product(overwrites)
+
+
+def get_sweep(hyper):
+ """Defines the hyper-parameters sweeps for grid search."""
+
+ random_sign_init = hyper.sweep('config.model.backbone.random_sign_init',
+ [-0.5, 0.5])
+ fast_weight_lr_multiplier = hyper.sweep('config.fast_weight_lr_multiplier',
+ [0.5, 1.0, 2.0])
+
+ return hyper.product([random_sign_init, fast_weight_lr_multiplier])
diff --git a/experimental/robust_segvit/configs/street_hazards/deterministic.py b/experimental/robust_segvit/configs/street_hazards/deterministic.py
new file mode 100644
index 000000000..0c044eae9
--- /dev/null
+++ b/experimental/robust_segvit/configs/street_hazards/deterministic.py
@@ -0,0 +1,257 @@
+# coding=utf-8
+# Copyright 2022 The Uncertainty Baselines Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Train segmenter model on street_hazards.
+
+Compare performance from deterministic upstream checkpoints.
+
+"""
+# pylint: enable=line-too-long
+
+import ml_collections
+import os
+import datetime
+
+_CITYSCAPES_FINE_TRAIN_SIZE = 2975
+_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
+
+_ADE20K_TRAIN_SIZE = 20210
+_PASCAL_VOC_TRAIN_SIZE = 10582
+_PASCAL_CONTEXT_TRAIN_SIZE = 4998
+_STREET_HAZARDS_TRAIN_SIZE = 5125
+
+TRAIN_SIZES = {
+ 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE,
+ 'ade20k': _ADE20K_TRAIN_SIZE,
+ 'ade20k_ind': _ADE20K_TRAIN_SIZE,
+ 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE,
+ 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE,
+ 'street_hazards': _STREET_HAZARDS_TRAIN_SIZE
+
+}
+
+# Model specs.
+LOAD_PRETRAINED_BACKBONE = True
+BACKBONE_ORIGIN = 'vision_transformer'
+VIT_SIZE = 'L'
+STRIDE = 16
+RESNET_SIZE = None
+CLASSIFIER = 'token'
+target_size = (720, 720)
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
+
+
+# Upstream
+MODEL_PATHS = {
+ # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 with augreg
+ ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
+ 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
+ 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+}
+
+
+MODEL_PATH = MODEL_PATHS[(BACKBONE_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE,
+ CLASSIFIER, UPSTREAM_TASK)]
+
+if VIT_SIZE == 'B':
+ mlp_dim = 3072
+ num_heads = 12
+ num_layers = 12
+ hidden_size = 768
+elif VIT_SIZE == 'L':
+ mlp_dim = 4096
+ num_heads = 16
+ num_layers = 24
+ hidden_size = 1024
+
+TRAIN_SAMPLES = 32
+
+
+def get_config(runlocal=''):
+ """Returns the configuration for ADE20k_ind segmentation."""
+
+ runlocal = bool(runlocal)
+
+ config = ml_collections.ConfigDict()
+ config.experiment_name = 'street_hazards_deterministic'
+
+ # Dataset.
+ config.dataset_name = 'robust_segvit_segmentation'
+ config.dataset_configs = ml_collections.ConfigDict()
+ config.dataset_configs.target_size = target_size
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
+
+ config.dataset_configs.train_split = 'train'
+ config.dataset_configs.name = 'street_hazards'
+ config.dataset_configs.dataset_name = '' # ood name flag to write in eval.
+
+ # Model.
+ config.model_name = 'segvit'
+ config.model = ml_collections.ConfigDict()
+
+ config.model.patches = ml_collections.ConfigDict()
+ config.model.patches.size = (STRIDE, STRIDE)
+
+ config.model.backbone = ml_collections.ConfigDict()
+ config.model.backbone.type = 'vit'
+ config.model.backbone.mlp_dim = mlp_dim
+ config.model.backbone.num_heads = num_heads
+ config.model.backbone.num_layers = num_layers
+ config.model.backbone.hidden_size = hidden_size
+ config.model.backbone.dropout_rate = 0.1
+ config.model.backbone.attention_dropout_rate = 0.0
+ config.model.backbone.classifier = CLASSIFIER
+
+ # Decoder
+ config.model.decoder = ml_collections.ConfigDict()
+ config.model.decoder.type = 'linear'
+
+ # Training.
+ config.trainer_name = 'segvit_trainer'
+ config.optimizer = 'adam'
+ config.optimizer_configs = ml_collections.ConfigDict()
+ config.l2_decay_factor = 0.0
+ config.max_grad_norm = 1.0
+ config.label_smoothing = None
+ config.num_training_epochs = ml_collections.FieldReference(100)
+ config.batch_size = 32
+ config.rng_seed = 0
+ config.focal_loss_gamma = 0.0
+
+ # Learning rate.
+ config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name)
+ config.steps_per_epoch = config.get_ref(
+ 'num_train_examples') // config.get_ref('batch_size')
+ # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine.
+ config.lr_configs = ml_collections.ConfigDict()
+ config.lr_configs.learning_rate_schedule = 'compound'
+ config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+ config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch')
+ config.lr_configs.steps_per_cycle = config.get_ref(
+ 'num_training_epochs') * config.get_ref('steps_per_epoch')
+ config.lr_configs.base_learning_rate = 1e-4
+
+ # model and data dtype
+ config.model_dtype_str = 'float32'
+ config.data_dtype_str = 'float32'
+
+ # load pretrained backbone
+ config.load_pretrained_backbone = LOAD_PRETRAINED_BACKBONE
+ config.pretrained_backbone_configs = ml_collections.ConfigDict()
+ config.pretrained_backbone_configs.checkpoint_format = BACKBONE_ORIGIN
+ config.pretrained_backbone_configs.checkpoint_path = MODEL_PATH
+ config.pretrained_backbone_configs.token_init = True
+ config.pretrained_backbone_configs.classifier = 'token'
+ config.pretrained_backbone_configs.backbone_type = 'vit'
+
+ # Logging.
+ config.write_summary = True
+ config.write_xm_measurements = True # write XM measurements
+ config.xprof = False # Profile using xprof.
+ config.checkpoint = True # Do checkpointing.
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+ config.debug_train = False # Debug mode during training.
+ config.debug_eval = False # Debug mode during eval.
+ config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+ # Evaluation.
+ config.eval_configs = ml_collections.ConfigDict()
+ config.eval_configs.mode = 'standard'
+ config.eval_mode = False
+ config.eval_covariate_shift = True
+ config.eval_label_shift = True
+ config.model.input_shape = target_size
+
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'msp'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
+ if runlocal:
+ config.count_flops = False
+ config.dataset_configs.train_target_size = (128, 128)
+ config.model.input_shape = config.dataset_configs.train_target_size
+ config.batch_size = 8
+ config.num_training_epochs = 5
+ config.warmup_steps = 0
+ config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]'
+ config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]'
+ config.num_train_examples = TRAIN_SAMPLES
+
+ return config
+
+
+def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task):
+ """Defines checkpoints for sweep."""
+ overwrites = []
+ if resnet_size is not None:
+ raise NotImplementedError('')
+ else:
+ overwrites.append(
+ hyper.sweep('config.model.patches', [{
+ 'size': (stride, stride)
+ }]))
+
+ if vit_size == 'B':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768]))
+ elif vit_size == 'L':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024]))
+ else:
+ raise NotImplementedError('')
+
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_format',
+ [backbone_origin]))
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [
+ MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task)]
+ ]))
+
+ return hyper.product(overwrites)
+
+
+def get_sweep(hyper):
+ """Defines the hyper-parameters sweeps for doing grid search."""
+
+ learning_rate = hyper.sweep('config.lr_configs.base_learning_rate',
+ [1e-4, 3e-4, 3e-5, 1e-5])
+
+ epochs = hyper.sweep('config.num_training_epochs', [100, 50, 200, 250])
+
+ return hyper.product([learning_rate, epochs])
diff --git a/experimental/robust_segvit/configs/street_hazards/deterministic_eval.py b/experimental/robust_segvit/configs/street_hazards/deterministic_eval.py
new file mode 100644
index 000000000..17e4493a5
--- /dev/null
+++ b/experimental/robust_segvit/configs/street_hazards/deterministic_eval.py
@@ -0,0 +1,247 @@
+# coding=utf-8
+# Copyright 2022 The Uncertainty Baselines Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Evaluate segmenter model on street_hazards.
+
+
+"""
+# pylint: enable=line-too-long
+
+import ml_collections
+import os
+import datetime
+
+_CITYSCAPES_FINE_TRAIN_SIZE = 2975
+_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
+
+_ADE20K_TRAIN_SIZE = 20210
+_PASCAL_VOC_TRAIN_SIZE = 10582
+_PASCAL_CONTEXT_TRAIN_SIZE = 4998
+_STREET_HAZARDS_TRAIN_SIZE = 5125
+
+TRAIN_SIZES = {
+ 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE,
+ 'ade20k': _ADE20K_TRAIN_SIZE,
+ 'ade20k_ind': _ADE20K_TRAIN_SIZE,
+ 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE,
+ 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE,
+ 'street_hazards': _STREET_HAZARDS_TRAIN_SIZE
+
+}
+
+# Model specs.
+target_size = (720, 720)
+
+LOAD_PRETRAINED_BACKBONE = True
+VIT_SIZE = 'L'
+STRIDE = 16
+RESNET_SIZE = None
+CLASSIFIER = 'token'
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
+target_size = (720, 720)
+
+CHECKPOINT_ORIGIN = 'ub'
+EXPERIMENTID='det_run1'
+
+# Upstream
+MODEL_PATHS = {
+ ('ub', 'L', 16, None, 'token', 'det_run1'):
+ 'gs://ub-ekb/segmenter/street_hazards/deterministic/deterministic_2022-09-27-07-32-08',
+}
+
+
+MODEL_PATH = MODEL_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE,
+ RESNET_SIZE, CLASSIFIER, EXPERIMENTID)]
+
+if VIT_SIZE == 'B':
+ mlp_dim = 3072
+ num_heads = 12
+ num_layers = 12
+ hidden_size = 768
+elif VIT_SIZE == 'L':
+ mlp_dim = 4096
+ num_heads = 16
+ num_layers = 24
+ hidden_size = 1024
+
+TRAIN_SAMPLES = 32
+
+
+def get_config(runlocal=''):
+ """Returns the configuration for ADE20k_ind segmentation."""
+
+ runlocal = bool(runlocal)
+
+ config = ml_collections.ConfigDict()
+ config.experiment_name = 'street_hazards_deterministic_eval'
+
+ # Dataset.
+ config.dataset_name = 'robust_segvit_segmentation'
+ config.dataset_configs = ml_collections.ConfigDict()
+ config.dataset_configs.target_size = target_size
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
+
+ config.dataset_configs.train_split = 'train'
+ config.dataset_configs.name = 'street_hazards'
+ config.dataset_configs.dataset_name = '' # ood name flag to write in eval.
+
+ # Model.
+ config.model_name = 'segvit'
+ config.model = ml_collections.ConfigDict()
+
+ config.model.patches = ml_collections.ConfigDict()
+ config.model.patches.size = (STRIDE, STRIDE)
+
+ config.model.backbone = ml_collections.ConfigDict()
+ config.model.backbone.type = 'vit'
+ config.model.backbone.mlp_dim = mlp_dim
+ config.model.backbone.num_heads = num_heads
+ config.model.backbone.num_layers = num_layers
+ config.model.backbone.hidden_size = hidden_size
+ config.model.backbone.dropout_rate = 0.0
+ config.model.backbone.attention_dropout_rate = 0.0
+ config.model.backbone.classifier = CLASSIFIER
+
+ # Decoder
+ config.model.decoder = ml_collections.ConfigDict()
+ config.model.decoder.type = 'linear'
+
+ # Training.
+ config.trainer_name = 'segvit_trainer'
+ config.optimizer = 'adam'
+ config.optimizer_configs = ml_collections.ConfigDict()
+ config.l2_decay_factor = 0.0
+ config.max_grad_norm = 1.0
+ config.label_smoothing = None
+ config.num_training_epochs = ml_collections.FieldReference(100)
+ config.batch_size = 32
+ config.rng_seed = 0
+ config.focal_loss_gamma = 0.0
+
+ # Learning rate.
+ config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name)
+ config.steps_per_epoch = config.get_ref(
+ 'num_train_examples') // config.get_ref('batch_size')
+ # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine.
+ config.lr_configs = ml_collections.ConfigDict()
+ config.lr_configs.learning_rate_schedule = 'compound'
+ config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+ config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch')
+ config.lr_configs.steps_per_cycle = config.get_ref(
+ 'num_training_epochs') * config.get_ref('steps_per_epoch')
+ config.lr_configs.base_learning_rate = 1e-4
+
+ # model and data dtype
+ config.model_dtype_str = 'float32'
+ config.data_dtype_str = 'float32'
+
+ # Load checkpoint
+ config.checkpoint_configs = ml_collections.ConfigDict()
+ config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN
+ config.checkpoint_configs.checkpoint_path = MODEL_PATH
+ config.checkpoint_configs.classifier = 'token'
+
+ # Logging.
+ config.write_summary = True
+ config.write_xm_measurements = True # write XM measurements
+ config.xprof = False # Profile using xprof.
+ config.checkpoint = True # Do checkpointing.
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+ config.debug_train = False # Debug mode during training.
+ config.debug_eval = False # Debug mode during eval.
+ config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+ # Evaluation.
+ config.eval_configs = ml_collections.ConfigDict()
+ config.eval_configs.mode = 'standard'
+ config.eval_mode = True
+ config.eval_covariate_shift = True
+ config.eval_label_shift = True
+ config.model.input_shape = target_size
+ config.eval_configs.store_logits = False
+
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'mlogit'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
+ if runlocal:
+ config.count_flops = False
+ config.dataset_configs.train_target_size = (128, 128)
+ config.model.input_shape = config.dataset_configs.train_target_size
+ config.batch_size = 8
+ config.num_training_epochs = 5
+ config.warmup_steps = 0
+ config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]'
+ config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]'
+ config.num_train_examples = TRAIN_SAMPLES
+
+ return config
+
+
+def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task):
+ """Defines checkpoints for sweep."""
+ overwrites = []
+ if resnet_size is not None:
+ raise NotImplementedError('')
+ else:
+ overwrites.append(
+ hyper.sweep('config.model.patches', [{
+ 'size': (stride, stride)
+ }]))
+
+ if vit_size == 'B':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768]))
+ elif vit_size == 'L':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024]))
+ else:
+ raise NotImplementedError('')
+
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_format',
+ [backbone_origin]))
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [
+ MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task)]
+ ]))
+
+ return hyper.product(overwrites)
+
+
+def get_sweep(hyper):
+ return hyper.product([])
diff --git a/experimental/robust_segvit/configs/street_hazards/gp.py b/experimental/robust_segvit/configs/street_hazards/gp.py
new file mode 100644
index 000000000..ba4c8bc4e
--- /dev/null
+++ b/experimental/robust_segvit/configs/street_hazards/gp.py
@@ -0,0 +1,273 @@
+# coding=utf-8
+# Copyright 2022 The Uncertainty Baselines Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Train segmenter model on street_hazards.
+
+
+"""
+# pylint: enable=line-too-long
+
+import ml_collections
+import os
+import datetime
+
+_CITYSCAPES_FINE_TRAIN_SIZE = 2975
+_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
+
+_ADE20K_TRAIN_SIZE = 20210
+_PASCAL_VOC_TRAIN_SIZE = 10582
+_PASCAL_CONTEXT_TRAIN_SIZE = 4998
+_STREET_HAZARDS_TRAIN_SIZE = 5125
+
+TRAIN_SIZES = {
+ 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE,
+ 'ade20k': _ADE20K_TRAIN_SIZE,
+ 'ade20k_ind': _ADE20K_TRAIN_SIZE,
+ 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE,
+ 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE,
+ 'street_hazards': _STREET_HAZARDS_TRAIN_SIZE
+
+}
+
+# Model specs.
+LOAD_PRETRAINED_BACKBONE = True
+BACKBONE_ORIGIN = 'vision_transformer'
+VIT_SIZE = 'L'
+STRIDE = 16
+RESNET_SIZE = None
+CLASSIFIER = 'token'
+target_size = (720, 720)
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
+
+
+# Upstream
+MODEL_PATHS = {
+ # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384 with augreg
+ ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
+ 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
+ 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+}
+
+
+MODEL_PATH = MODEL_PATHS[(BACKBONE_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE,
+ CLASSIFIER, UPSTREAM_TASK)]
+
+if VIT_SIZE == 'B':
+ mlp_dim = 3072
+ num_heads = 12
+ num_layers = 12
+ hidden_size = 768
+elif VIT_SIZE == 'L':
+ mlp_dim = 4096
+ num_heads = 16
+ num_layers = 24
+ hidden_size = 1024
+
+TRAIN_SAMPLES = 32
+
+
+def get_config(runlocal=''):
+ """Returns the configuration for street_hazards segmentation."""
+
+ runlocal = bool(runlocal)
+
+ config = ml_collections.ConfigDict()
+ config.experiment_name = 'street_hazards_gp_hyper'
+
+ # Dataset.
+ config.dataset_name = 'robust_segvit_segmentation'
+ config.dataset_configs = ml_collections.ConfigDict()
+ config.dataset_configs.target_size = target_size
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
+
+ config.dataset_configs.train_split = 'train'
+ config.dataset_configs.name = 'street_hazards'
+ config.dataset_configs.dataset_name = '' # ood name flag to write in eval.
+
+ # Model.
+ config.model_name = 'segvit'
+ config.model = ml_collections.ConfigDict()
+
+ config.model.patches = ml_collections.ConfigDict()
+ config.model.patches.size = (STRIDE, STRIDE)
+
+ config.model.backbone = ml_collections.ConfigDict()
+ config.model.backbone.type = 'vit'
+ config.model.backbone.mlp_dim = mlp_dim
+ config.model.backbone.num_heads = num_heads
+ config.model.backbone.num_layers = num_layers
+ config.model.backbone.hidden_size = hidden_size
+ config.model.backbone.dropout_rate = 0.1
+ config.model.backbone.attention_dropout_rate = 0.0
+ config.model.backbone.classifier = CLASSIFIER
+
+ # Decoder
+ config.model.decoder = ml_collections.ConfigDict()
+ config.model.decoder.type = 'gp'
+
+ # GP layer params
+ config.model.decoder.gp_layer = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.covmat_kwargs = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.covmat_kwargs.ridge_penalty = 1.
+ # Disabling momentum would give the exact covariance update for finetuning.
+ # NOTE(review): 0.99 leaves momentum enabled -- confirm this is intended.
+ config.model.decoder.gp_layer.covmat_kwargs.momentum = 0.99
+ config.model.decoder.mean_field_factor = 1.
+ # Additional params
+ config.model.decoder.gp_layer.normalize_input = True
+ config.model.decoder.gp_layer.hidden_kwargs = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.hidden_kwargs.feature_scale = 1.
+
+ # Training.
+ config.trainer_name = 'segvit_trainer'
+ config.optimizer = 'adam'
+ config.optimizer_configs = ml_collections.ConfigDict()
+ config.l2_decay_factor = 0.0
+ config.max_grad_norm = 1.0
+ config.label_smoothing = None
+ config.num_training_epochs = ml_collections.FieldReference(100)
+ config.batch_size = 32
+ config.rng_seed = 0
+ config.focal_loss_gamma = 0.0
+
+ # Learning rate.
+ config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name)
+ config.steps_per_epoch = config.get_ref(
+ 'num_train_examples') // config.get_ref('batch_size')
+ # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine.
+ config.lr_configs = ml_collections.ConfigDict()
+ config.lr_configs.learning_rate_schedule = 'compound'
+ config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+ config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch')
+ config.lr_configs.steps_per_cycle = config.get_ref(
+ 'num_training_epochs') * config.get_ref('steps_per_epoch')
+ config.lr_configs.base_learning_rate = 1e-4
+
+ # model and data dtype
+ config.model_dtype_str = 'float32'
+ config.data_dtype_str = 'float32'
+
+ # load pretrained backbone
+ config.load_pretrained_backbone = LOAD_PRETRAINED_BACKBONE
+ config.pretrained_backbone_configs = ml_collections.ConfigDict()
+ config.pretrained_backbone_configs.checkpoint_format = BACKBONE_ORIGIN
+ config.pretrained_backbone_configs.checkpoint_path = MODEL_PATH
+ config.pretrained_backbone_configs.token_init = True
+ config.pretrained_backbone_configs.classifier = 'token'
+ config.pretrained_backbone_configs.backbone_type = 'vit'
+
+ # Logging.
+ config.write_summary = True
+ config.write_xm_measurements = True # write XM measurements
+ config.xprof = False # Profile using xprof.
+ config.checkpoint = True # Do checkpointing.
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+ config.debug_train = False # Debug mode during training.
+ config.debug_eval = False # Debug mode during eval.
+ config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+ # Evaluation.
+ config.eval_configs = ml_collections.ConfigDict()
+ config.eval_configs.mode = 'standard'
+ config.eval_mode = False
+ config.eval_covariate_shift = True
+ config.eval_label_shift = True
+ config.model.input_shape = target_size
+
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'msp'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
+ if runlocal:
+ config.count_flops = False
+ config.dataset_configs.train_target_size = (128, 128)
+ config.model.input_shape = config.dataset_configs.train_target_size
+ config.batch_size = 8
+ config.num_training_epochs = 5
+ config.warmup_steps = 0
+ config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]'
+ config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]'
+ config.num_train_examples = TRAIN_SAMPLES
+
+ return config
+
+
+def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task):
+ """Defines checkpoints for sweep."""
+ overwrites = []
+ if resnet_size is not None:
+ raise NotImplementedError('')
+ else:
+ overwrites.append(
+ hyper.sweep('config.model.patches', [{
+ 'size': (stride, stride)
+ }]))
+
+ if vit_size == 'B':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768]))
+ elif vit_size == 'L':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024]))
+ else:
+ raise NotImplementedError('')
+
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_format',
+ [backbone_origin]))
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [
+ MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task)]
+ ]))
+
+ return hyper.product(overwrites)
+
+
+def get_sweep(hyper):
+ """Defines the hyper-parameters sweeps for doing grid search."""
+
+ parameters = [
+ hyper.sweep('config.model.decoder.gp_layer.normalize_input',
+ [True, False]),
+ hyper.sweep('config.model.decoder.mean_field_factor',
+ hyper.discrete(range(1, 10))),
+ hyper.sweep('config.model.decoder.gp_layer.hidden_kwargs.feature_scale',
+ [1.0, 2.0]),
+ ]
+
+ return hyper.product(parameters)
\ No newline at end of file
diff --git a/experimental/robust_segvit/configs/street_hazards/gp_eval.py b/experimental/robust_segvit/configs/street_hazards/gp_eval.py
new file mode 100644
index 000000000..6ab27e8f5
--- /dev/null
+++ b/experimental/robust_segvit/configs/street_hazards/gp_eval.py
@@ -0,0 +1,260 @@
+# coding=utf-8
+# Copyright 2022 The Uncertainty Baselines Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Evaluate segmenter_gp model on street_hazards.
+
+
+"""
+# pylint: enable=line-too-long
+
+import ml_collections
+import os
+import datetime
+
+_CITYSCAPES_FINE_TRAIN_SIZE = 2975
+_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
+
+_ADE20K_TRAIN_SIZE = 20210
+_PASCAL_VOC_TRAIN_SIZE = 10582
+_PASCAL_CONTEXT_TRAIN_SIZE = 4998
+_STREET_HAZARDS_TRAIN_SIZE = 5125
+
+TRAIN_SIZES = {
+ 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE,
+ 'ade20k': _ADE20K_TRAIN_SIZE,
+ 'ade20k_ind': _ADE20K_TRAIN_SIZE,
+ 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE,
+ 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE,
+ 'street_hazards': _STREET_HAZARDS_TRAIN_SIZE
+
+}
+
+# Model specs.
+target_size = (720, 720)
+
+LOAD_PRETRAINED_BACKBONE = True
+VIT_SIZE = 'L'
+STRIDE = 16
+RESNET_SIZE = None
+CLASSIFIER = 'token'
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
+target_size = (720, 720)
+
+CHECKPOINT_ORIGIN = 'ub'
+EXPERIMENTID='gp_run1'
+
+# Upstream
+MODEL_PATHS = {
+ ('ub', 'L', 16, None, 'token', 'gp_run1'):
+ 'gs://ub-ekb/segmenter/street_hazards/gp/gp_2022-10-03-15-05-54',
+}
+
+
+MODEL_PATH = MODEL_PATHS[(CHECKPOINT_ORIGIN, VIT_SIZE, STRIDE,
+ RESNET_SIZE, CLASSIFIER, EXPERIMENTID)]
+
+if VIT_SIZE == 'B':
+ mlp_dim = 3072
+ num_heads = 12
+ num_layers = 12
+ hidden_size = 768
+elif VIT_SIZE == 'L':
+ mlp_dim = 4096
+ num_heads = 16
+ num_layers = 24
+ hidden_size = 1024
+
+TRAIN_SAMPLES = 32
+
+
+def get_config(runlocal=''):
+ """Returns the configuration for street_hazards segmentation."""
+
+ runlocal = bool(runlocal)
+
+ config = ml_collections.ConfigDict()
+ config.experiment_name = 'street_hazards_gp_eval'
+
+ # Dataset.
+ config.dataset_name = 'robust_segvit_segmentation'
+ config.dataset_configs = ml_collections.ConfigDict()
+ config.dataset_configs.target_size = target_size
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
+
+ config.dataset_configs.train_split = 'train'
+ config.dataset_configs.name = 'street_hazards'
+ config.dataset_configs.dataset_name = '' # ood name flag to write in eval.
+
+ # Model.
+ config.model_name = 'segvit'
+ config.model = ml_collections.ConfigDict()
+
+ config.model.patches = ml_collections.ConfigDict()
+ config.model.patches.size = (STRIDE, STRIDE)
+
+ config.model.backbone = ml_collections.ConfigDict()
+ config.model.backbone.type = 'vit'
+ config.model.backbone.mlp_dim = mlp_dim
+ config.model.backbone.num_heads = num_heads
+ config.model.backbone.num_layers = num_layers
+ config.model.backbone.hidden_size = hidden_size
+ config.model.backbone.dropout_rate = 0.0
+ config.model.backbone.attention_dropout_rate = 0.0
+ config.model.backbone.classifier = CLASSIFIER
+
+ # Decoder
+ config.model.decoder = ml_collections.ConfigDict()
+ config.model.decoder.type = 'gp'
+
+ # GP layer params
+ config.model.decoder.gp_layer = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.covmat_kwargs = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.covmat_kwargs.ridge_penalty = 1.
+ # Disabling momentum would give the exact covariance update for finetuning.
+ # NOTE(review): 0.99 leaves momentum enabled -- confirm this is intended.
+ config.model.decoder.gp_layer.covmat_kwargs.momentum = 0.99
+ config.model.decoder.mean_field_factor = 1.
+ # Additional params
+ config.model.decoder.gp_layer.normalize_input = True
+ config.model.decoder.gp_layer.hidden_kwargs = ml_collections.ConfigDict()
+ config.model.decoder.gp_layer.hidden_kwargs.feature_scale = 1.
+
+ # Training.
+ config.trainer_name = 'segvit_trainer'
+ config.optimizer = 'adam'
+ config.optimizer_configs = ml_collections.ConfigDict()
+ config.l2_decay_factor = 0.0
+ config.max_grad_norm = 1.0
+ config.label_smoothing = None
+ config.num_training_epochs = ml_collections.FieldReference(100)
+ config.batch_size = 32
+ config.rng_seed = 0
+ config.focal_loss_gamma = 0.0
+
+ # Learning rate.
+ config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name)
+ config.steps_per_epoch = config.get_ref(
+ 'num_train_examples') // config.get_ref('batch_size')
+ # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine.
+ config.lr_configs = ml_collections.ConfigDict()
+ config.lr_configs.learning_rate_schedule = 'compound'
+ config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+ config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch')
+ config.lr_configs.steps_per_cycle = config.get_ref(
+ 'num_training_epochs') * config.get_ref('steps_per_epoch')
+ config.lr_configs.base_learning_rate = 1e-4
+
+ # model and data dtype
+ config.model_dtype_str = 'float32'
+ config.data_dtype_str = 'float32'
+
+ # Load checkpoint
+ config.checkpoint_configs = ml_collections.ConfigDict()
+ config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN
+ config.checkpoint_configs.checkpoint_path = MODEL_PATH
+ config.checkpoint_configs.classifier = 'token'
+
+ # Logging.
+ config.write_summary = True
+ config.write_xm_measurements = True # write XM measurements
+ config.xprof = False # Profile using xprof.
+ config.checkpoint = True # Do checkpointing.
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+ config.debug_train = False # Debug mode during training.
+ config.debug_eval = False # Debug mode during eval.
+ config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+ # Evaluation.
+ config.eval_configs = ml_collections.ConfigDict()
+ config.eval_configs.mode = 'standard'
+ config.eval_mode = True
+ config.eval_covariate_shift = True
+ config.eval_label_shift = True
+ config.model.input_shape = target_size
+ config.eval_configs.store_logits = False
+
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'mlogit'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
+ if runlocal:
+ config.count_flops = False
+ config.dataset_configs.train_target_size = (128, 128)
+ config.model.input_shape = config.dataset_configs.train_target_size
+ config.batch_size = 8
+ config.num_training_epochs = 5
+ config.warmup_steps = 0
+ config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]'
+ config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]'
+ config.num_train_examples = TRAIN_SAMPLES
+
+ return config
+
+
+def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task):
+ """Defines checkpoints for sweep."""
+ overwrites = []
+ if resnet_size is not None:
+ raise NotImplementedError('')
+ else:
+ overwrites.append(
+ hyper.sweep('config.model.patches', [{
+ 'size': (stride, stride)
+ }]))
+
+ if vit_size == 'B':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768]))
+ elif vit_size == 'L':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024]))
+ else:
+ raise NotImplementedError('')
+
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_format',
+ [backbone_origin]))
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [
+ MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task)]
+ ]))
+
+ return hyper.product(overwrites)
+
+
+def get_sweep(hyper):
+ return hyper.product([])  # Empty list: no hyper-parameters are swept.
diff --git a/experimental/robust_segvit/configs/street_hazards/het.py b/experimental/robust_segvit/configs/street_hazards/het.py
new file mode 100644
index 000000000..5493e2ad3
--- /dev/null
+++ b/experimental/robust_segvit/configs/street_hazards/het.py
@@ -0,0 +1,276 @@
+# coding=utf-8
+# Copyright 2022 The Uncertainty Baselines Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Train segmenter model on street_hazards.
+
+Compare performance from deterministic upstream checkpoints.
+
+"""
+# pylint: enable=line-too-long
+
+import ml_collections
+import os
+import datetime
+
+_CITYSCAPES_FINE_TRAIN_SIZE = 2975
+_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
+
+_ADE20K_TRAIN_SIZE = 20210
+_PASCAL_VOC_TRAIN_SIZE = 10582
+_PASCAL_CONTEXT_TRAIN_SIZE = 4998
+_STREET_HAZARDS_TRAIN_SIZE = 5125
+
+TRAIN_SIZES = {
+ 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE,
+ 'ade20k': _ADE20K_TRAIN_SIZE,
+ 'ade20k_ind': _ADE20K_TRAIN_SIZE,
+ 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE,
+ 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE,
+ 'street_hazards': _STREET_HAZARDS_TRAIN_SIZE,
+}
+
+# Model specs.
+LOAD_PRETRAINED_BACKBONE = True
+BACKBONE_ORIGIN = 'vision_transformer'
+VIT_SIZE = 'L'
+STRIDE = 16
+RESNET_SIZE = None
+CLASSIFIER = 'token'
+target_size = (720, 720)
+UPSTREAM_TASK = 'augreg+i21k+imagenet2012'
+
+
+# Upstream
+MODEL_PATHS = {
+
+ # Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384
+ ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
+ 'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
+ ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
+ 'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+}
+
+
+MODEL_PATH = MODEL_PATHS[(BACKBONE_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE,
+ CLASSIFIER, UPSTREAM_TASK)]
+
+if VIT_SIZE == 'B':
+ mlp_dim = 3072
+ num_heads = 12
+ num_layers = 12
+ hidden_size = 768
+elif VIT_SIZE == 'L':
+ mlp_dim = 4096
+ num_heads = 16
+ num_layers = 24
+ hidden_size = 1024
+
+TRAIN_SAMPLES = 32
+
+
+def get_config(runlocal=''):
+ """Returns the configuration for street hazards segmentation."""
+
+ runlocal = bool(runlocal)
+
+ config = ml_collections.ConfigDict()
+ config.experiment_name = 'street_hazards_ind_segmenter_het_hyper'
+
+ # Dataset.
+ config.dataset_name = 'robust_segvit_segmentation'
+ config.dataset_configs = ml_collections.ConfigDict()
+ config.dataset_configs.target_size = target_size
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
+
+ config.dataset_configs.train_split = 'train'
+ config.dataset_configs.name = 'street_hazards'
+ config.dataset_configs.dataset_name = '' # ood name flag to write in eval.
+
+ # Model.
+ config.model_name = 'segvit'
+ config.model = ml_collections.ConfigDict()
+
+ config.model.patches = ml_collections.ConfigDict()
+ config.model.patches.size = (STRIDE, STRIDE)
+
+ config.model.backbone = ml_collections.ConfigDict()
+ config.model.backbone.type = 'vit'
+ config.model.backbone.mlp_dim = mlp_dim
+ config.model.backbone.num_heads = num_heads
+ config.model.backbone.num_layers = num_layers
+ config.model.backbone.hidden_size = hidden_size
+ config.model.backbone.dropout_rate = 0.1
+ config.model.backbone.attention_dropout_rate = 0.0
+ config.model.backbone.classifier = CLASSIFIER
+
+ # Decoder
+ config.model.decoder = ml_collections.ConfigDict()
+ config.model.decoder.type = 'het'
+
+ # Het layer params
+ # temp: wide sweep [0.15, 0.3, 0.5, 0.75, 1.0, 1.5, 2.0]
+ config.model.decoder.temperature = 1.0
+ # efficient low rank approx ~ FxK where K is the classes. False for K<20.
+ config.model.decoder.param_efficient = False
+ # F as a low rank approx of KxK matrix has num_factors:
+ # imagenet~15, jft~50, cifar~6, cityscapes~sweep(5-10).
+ config.model.decoder.num_factors = 5
+ # mc_samples: use as much as can be afforded, ideally > 10.
+ config.model.decoder.mc_samples = 1000
+ config.model.decoder.return_locs = False
+ # turn on to run an approx on KHW x KHW instead of KxK.
+ config.model.decoder.share_samples_across_batch = False
+
+ # Training.
+ config.trainer_name = 'segvit_trainer'
+ config.optimizer = 'adam'
+ config.optimizer_configs = ml_collections.ConfigDict()
+ config.l2_decay_factor = 0.0
+ config.max_grad_norm = 1.0
+ config.label_smoothing = None
+ config.num_training_epochs = ml_collections.FieldReference(100)
+ config.batch_size = 32
+ config.rng_seed = 0
+ config.focal_loss_gamma = 0.0
+
+ # Learning rate.
+ config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name)
+ config.steps_per_epoch = config.get_ref(
+ 'num_train_examples') // config.get_ref('batch_size')
+ # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine.
+ config.lr_configs = ml_collections.ConfigDict()
+ config.lr_configs.learning_rate_schedule = 'compound'
+ config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+ config.lr_configs.warmup_steps = 1 * config.get_ref('steps_per_epoch')
+ config.lr_configs.steps_per_cycle = config.get_ref(
+ 'num_training_epochs') * config.get_ref('steps_per_epoch')
+ config.lr_configs.base_learning_rate = 1e-4
+
+ # model and data dtype
+ config.model_dtype_str = 'float32'
+ config.data_dtype_str = 'float32'
+
+ # load pretrained backbone
+ config.load_pretrained_backbone = LOAD_PRETRAINED_BACKBONE
+ config.pretrained_backbone_configs = ml_collections.ConfigDict()
+ config.pretrained_backbone_configs.checkpoint_format = BACKBONE_ORIGIN
+ config.pretrained_backbone_configs.checkpoint_path = MODEL_PATH
+ config.pretrained_backbone_configs.token_init = True
+ config.pretrained_backbone_configs.classifier = 'token'
+ config.pretrained_backbone_configs.backbone_type = 'vit'
+
+ # Logging.
+ config.write_summary = True
+ config.write_xm_measurements = True # write XM measurements
+ config.xprof = False # Profile using xprof.
+ config.checkpoint = True # Do checkpointing.
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+ config.debug_train = False # Debug mode during training.
+ config.debug_eval = False # Debug mode during eval.
+ config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+ # Evaluation.
+ config.eval_configs = ml_collections.ConfigDict()
+ config.eval_configs.mode = 'standard'
+ config.eval_mode = False
+ config.eval_covariate_shift = True
+ config.eval_label_shift = True
+ config.model.input_shape = target_size
+
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'msp'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
+ if runlocal:
+ config.count_flops = False
+ config.dataset_configs.train_target_size = (128, 128)
+ config.model.input_shape = config.dataset_configs.train_target_size
+ config.batch_size = 8
+ config.num_training_epochs = 5
+ config.warmup_steps = 0
+ config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]'
+ config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]'
+ config.num_train_examples = TRAIN_SAMPLES
+
+ return config
+
+
+def checkpoint(hyper, backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task):
+ """Defines checkpoints for sweep."""
+ overwrites = []
+ if resnet_size is not None:
+ raise NotImplementedError('')
+ else:
+ overwrites.append(
+ hyper.sweep('config.model.patches', [{
+ 'size': (stride, stride)
+ }]))
+
+ if vit_size == 'B':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [3072]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [12]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [768]))
+ elif vit_size == 'L':
+ overwrites.append(hyper.sweep('config.model.backbone.mlp_dim', [4096]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_heads', [16]))
+ overwrites.append(hyper.sweep('config.model.backbone.num_layers', [24]))
+ overwrites.append(hyper.sweep('config.model.backbone.hidden_size', [1024]))
+ else:
+ raise NotImplementedError('')
+
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_format',
+ [backbone_origin]))
+ overwrites.append(
+ hyper.sweep('config.pretrained_backbone_configs.checkpoint_path', [
+ MODEL_PATHS[(backbone_origin, vit_size, stride, resnet_size,
+ classifier, upstream_task)]
+ ]))
+
+ return hyper.product(overwrites)
+
+
+def get_sweep(hyper):
+ """Defines the hyper-parameters sweeps for doing grid search."""
+ parameters = [
+ hyper.sweep('config.model.decoder.num_factors',
+ hyper.discrete([5, 10, 20, 50])),
+ hyper.sweep('config.model.decoder.temperature',
+ [0.15, 0.3, 0.5, 0.75, 1.0, 1.5, 2.0]),
+ hyper.sweep('config.model.decoder.share_samples_across_batch',
+ [True, False]),
+ hyper.sweep('config.model.decoder.param_efficient',
+ [True, False]),
+ ]
+
+ return hyper.product(parameters)
diff --git a/experimental/robust_segvit/configs/street_hazards/toy_model.py b/experimental/robust_segvit/configs/street_hazards/toy_model.py
new file mode 100644
index 000000000..0d251d923
--- /dev/null
+++ b/experimental/robust_segvit/configs/street_hazards/toy_model.py
@@ -0,0 +1,170 @@
+# coding=utf-8
+# Copyright 2022 The Uncertainty Baselines Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Train toy model on toy street_hazards dataset calling the robust_segvit codebase.
+
+"""
+# pylint: enable=line-too-long
+
+import ml_collections
+import os
+import datetime
+
+_CITYSCAPES_FINE_TRAIN_SIZE = 2975
+_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
+
+_ADE20K_TRAIN_SIZE = 20210
+_PASCAL_VOC_TRAIN_SIZE = 10582
+_PASCAL_CONTEXT_TRAIN_SIZE = 4998
+_STREET_HAZARDS_TRAIN_SIZE = 5125
+
+TRAIN_SIZES = {
+ 'cityscapes': _CITYSCAPES_FINE_TRAIN_SIZE,
+ 'ade20k': _ADE20K_TRAIN_SIZE,
+ 'ade20k_ind': _ADE20K_TRAIN_SIZE,
+ 'pascal_voc': _PASCAL_VOC_TRAIN_SIZE,
+ 'pascal_context': _PASCAL_CONTEXT_TRAIN_SIZE,
+ 'street_hazards': _STREET_HAZARDS_TRAIN_SIZE
+}
+
+# Model spec.
+STRIDE = 4
+mlp_dim = 2
+num_heads = 1
+num_layers = 1
+hidden_size = 1
+target_size = (512, 512)
+
+TRAIN_SAMPLES = 32
+
+
+def get_config(runlocal=''):
+ """Returns the configuration for street_hazards segmentation."""
+
+ runlocal = bool(runlocal)
+
+ config = ml_collections.ConfigDict()
+ config.experiment_name = 'street_hazards_segmenter_ind_toy_model'
+
+ # Dataset.
+ config.dataset_name = 'robust_segvit_segmentation'
+ config.dataset_configs = ml_collections.ConfigDict()
+ config.dataset_configs.target_size = target_size
+ config.dataset_configs.train_target_size = config.dataset_configs.get_ref(
+ 'target_size')
+ config.dataset_configs.denoise = None
+ config.dataset_configs.use_timestep = 0
+ config.dataset_configs.train_split = 'train'
+ config.dataset_configs.validation_split = 'validation'
+ config.dataset_configs.name = 'street_hazards'
+ config.dataset_configs.dataset_name = '' # ood name flag to write in eval.
+
+ # Model.
+ config.model_name = 'segvit'
+ config.model = ml_collections.ConfigDict()
+
+ config.model.patches = ml_collections.ConfigDict()
+ config.model.patches.size = (STRIDE, STRIDE)
+
+ config.model.backbone = ml_collections.ConfigDict()
+ config.model.backbone.type = 'vit'
+ config.model.backbone.mlp_dim = mlp_dim
+ config.model.backbone.num_heads = num_heads
+ config.model.backbone.num_layers = num_layers
+ config.model.backbone.hidden_size = hidden_size
+ config.model.backbone.dropout_rate = 0.1
+ config.model.backbone.attention_dropout_rate = 0.0
+ config.model.backbone.classifier = 'token'
+
+ # Decoder
+ config.model.decoder = ml_collections.ConfigDict()
+ config.model.decoder.type = 'linear'
+
+ # Training.
+ config.trainer_name = 'segvit_trainer'
+ config.optimizer = 'adam'
+ config.optimizer_configs = ml_collections.ConfigDict()
+ config.l2_decay_factor = 0.0
+ config.max_grad_norm = 1.0
+ config.label_smoothing = None
+ config.num_training_epochs = ml_collections.FieldReference(2)
+ config.batch_size = 32
+ config.rng_seed = 0
+ config.focal_loss_gamma = 0.0
+
+ # Learning rate.
+ config.num_train_examples = TRAIN_SIZES.get(config.dataset_configs.name)
+ config.steps_per_epoch = config.get_ref(
+ 'num_train_examples') // config.get_ref('batch_size')
+ # setting 'steps_per_cycle' to total_steps basically means non-cycling cosine.
+ config.lr_configs = ml_collections.ConfigDict()
+ config.lr_configs.learning_rate_schedule = 'compound'
+ config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
+ config.lr_configs.warmup_steps = 0
+ config.lr_configs.steps_per_cycle = config.get_ref(
+ 'num_training_epochs') * config.get_ref('steps_per_epoch')
+ config.lr_configs.base_learning_rate = 5e-4
+
+ # model and data dtype
+ config.model_dtype_str = 'float32'
+ config.data_dtype_str = 'float32'
+
+ # No pretrained checkpoint is configured here (init not included).
+
+ # Logging.
+ config.write_summary = True
+ config.write_xm_measurements = True # write XM measurements
+ config.xprof = False # Profile using xprof.
+ config.checkpoint = True # Do checkpointing.
+ config.checkpoint_steps = 5 * config.get_ref('steps_per_epoch')
+
+ config.debug_train = False # Debug mode during training.
+ config.debug_eval = False # Debug mode during eval.
+ config.log_eval_steps = 1 * config.get_ref('steps_per_epoch')
+
+ # Evaluation.
+ config.eval_mode = False
+ config.eval_configs = ml_collections.ConfigDict()
+ config.eval_configs.mode = 'standard'
+ config.eval_covariate_shift = True
+ config.eval_label_shift = True
+
+ config.eval_robustness_configs = ml_collections.ConfigDict()
+ config.eval_robustness_configs.auc_online = True
+ config.eval_robustness_configs.method_name = 'mlogit'
+
+ # wandb.ai configurations.
+ config.use_wandb = False
+ config.wandb_dir = 'wandb'
+ config.wandb_project = 'rdl-debug'
+ config.wandb_entity = 'ekellbuch'
+ config.wandb_exp_name = None # Give experiment a name.
+ config.wandb_exp_name = (
+ os.path.splitext(os.path.basename(__file__))[0] + '_' +
+ datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
+ config.wandb_exp_group = None # Give experiment a group name.
+
+ if runlocal:
+ config.count_flops = False
+ config.batch_size = 8
+ config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]'
+ config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]'
+ config.num_train_examples = TRAIN_SAMPLES
+ return config
+
+
+def get_sweep(hyper):
+ return hyper.product([])  # Empty list: no hyper-parameters are swept.
diff --git a/experimental/robust_segvit/custom_models.py b/experimental/robust_segvit/custom_models.py
index f4b2b68a0..0cab33f38 100644
--- a/experimental/robust_segvit/custom_models.py
+++ b/experimental/robust_segvit/custom_models.py
@@ -391,6 +391,11 @@ def global_unc_metrics_fn(
assert isinstance(all_unc_confusion_mats, list) # List of eval batches.
cm = np.sum(all_unc_confusion_mats, axis=0) # Sum over eval batches.
+ if cm.ndim == 2: # [batch_size, 4]
+ pass
+ elif cm.ndim == 3: # [num_devices, batch_size per device, 4]
+ cm = np.sum(cm, axis=0) # sum over devices
+
assert cm.ndim == 2, ('Expecting uncertainty confusion matrix to have shape '
'[batch_size, 4], got '
f'{cm.shape}.')
diff --git a/experimental/robust_segvit/custom_segmentation_trainer.py b/experimental/robust_segvit/custom_segmentation_trainer.py
index 086ab1ae9..a7d19690e 100644
--- a/experimental/robust_segvit/custom_segmentation_trainer.py
+++ b/experimental/robust_segvit/custom_segmentation_trainer.py
@@ -44,12 +44,18 @@
from ensemble_utils import log_average_softmax_probs # local file import from experimental.robust_segvit
from inference import process_batch # local file import from experimental.robust_segvit
from ood_metrics import get_ood_metrics # local file import from experimental.robust_segvit
-from ood_metrics import get_ood_score # local file import from experimental.robust_segvit
from pretrainer_utils import convert_torch_to_jax_checkpoint # local file import from experimental.robust_segvit
+from pretrainer_utils import convert_vision_transformer_to_scenic # local file import from experimental.robust_segvit
from uncertainty_metrics import get_uncertainty_confusion_matrix # local file import from experimental.robust_segvit
-
+from checkpoint_utils import load_checkpoints_eval
+from checkpoint_utils import load_checkpoints_backbone
+import h5py
+import os
import resource
import sys
+import robustness_metrics as rm
+from metrics_multihost import ComputeOODAUCMetric
+from metrics_multihost import host_all_gather_metrics
Batch = Dict[str, jnp.ndarray]
MetricFn = Callable[[jnp.ndarray, Dict[str, jnp.ndarray]],
@@ -126,6 +132,7 @@ def evaluate(train_state: train_utils.TrainState,
global_metrics_fn: Any,
global_unc_metrics_fn: Optional[Any],
prefix: str = 'valid',
+ workdir: str = '',
) -> Dict[str, Any]:
"""Model evaluator.
@@ -160,17 +167,54 @@ def evaluate(train_state: train_utils.TrainState,
# Evaluate global metrics on one of the hosts (lead_host), but given
# intermediate values collected from all hosts.
- for _ in range(steps_per_eval):
+ # setup calibration evaluation
+ ece_num_bins = config.get('ece_num_bins', 15)
+ ece_metric = rm.metrics.ExpectedCalibrationError(num_bins=ece_num_bins)._metric
+ calib_auc = rm.metrics.CalibrationAUC(correct_pred_as_pos_label=False)._metric
+
+ # store logits
+ store_logits = config.eval_configs.get('store_logits', False)
+
+ if store_logits:
+ store_logits_fname = os.path.join(workdir, "{}_{}_val.h5py".format(prefix,"logits"))
+ f = h5py.File(store_logits_fname, 'w', libver='latest')
+ f.swmr_mode = True # single write multi-read
+ input_shape = dataset.meta_data['input_shape'][1:3]
+ num_classes = dataset.meta_data['num_classes']
+ num_eval_examples = int(steps_per_eval * config.batch_size)
+ logits_out = f.create_dataset('logits', (num_eval_examples,) + input_shape + (num_classes,))
+ inputs_out = f.create_dataset('inputs', (num_eval_examples,) + input_shape + (3,))
+ labels_out = f.create_dataset('labels', (num_eval_examples,) + input_shape)
+
+ for step_ in range(steps_per_eval):
eval_batch = next(dataset.valid_iter)
e_batch, e_logits, e_metrics, confusion_matrix, unc_confusion_matrix = eval_step_pmapped(
train_state=train_state, batch=eval_batch)
eval_metrics.append(train_utils.unreplicate_and_get(e_metrics))
+
+ probs = jax.nn.softmax(e_logits, axis=-1)
+ # updates on each host separately
+ ece_metric.update_state(labels=e_batch['label'], probabilities=probs, sample_weight=e_batch['batch_mask'])
+ y_pred = jnp.argmax(probs, axis=-1) # predicted label indices
+ confidence = jnp.max(probs, axis=-1) # confidence score for predicted labels
+ calib_auc.update_state(y_true=e_batch['label'], y_pred=y_pred, confidence=confidence, sample_weight=e_batch['batch_mask'])
+
if lead_host and global_metrics_fn is not None:
# Collect data to be sent for computing global metrics.
eval_all_confusion_mats.append(to_cpu(confusion_matrix, all_gather=True))
eval_all_unc_confusion_mats.append(
to_cpu(unc_confusion_matrix, all_gather=True))
+ if store_logits:
+ start_idx = step_ * config.batch_size
+ end_idx = start_idx + config.batch_size
+ logits_out[start_idx:end_idx] = e_logits
+ inputs_out[start_idx:end_idx] = e_batch['inputs']
+ labels_out[start_idx:end_idx] = e_batch['label']
+
+ if store_logits:
+ f.close()
+
# Compute global metrics
eval_global_metrics_summary = {}
if lead_host and global_metrics_fn is not None:
@@ -190,6 +234,17 @@ def evaluate(train_state: train_utils.TrainState,
prefix=prefix,
)
+ del e_metrics, eval_batch, eval_metrics, eval_global_metrics_summary
+ del eval_all_confusion_mats
+ del eval_all_unc_confusion_mats
+
+ # Gather uncertainty metrics from all hosts and write value:
+ ece_metric = host_all_gather_metrics(ece_metric)
+ calib_auc = host_all_gather_metrics(calib_auc)
+ writer.write_scalars(step=step, scalars={'{}_ece'.format(prefix): ece_metric.result(),
+ '{}_calib_auc'.format(prefix): calib_auc.result(),
+ } )
+
# Visualize val predictions for one batch:
if lead_host:
# in eval_step we do not use all_gather in batch or logits
@@ -206,10 +261,7 @@ def evaluate(train_state: train_utils.TrainState,
writer.flush()
# Free some memory
- del eval_metrics
- del eval_global_metrics_summary
- del eval_all_confusion_mats
- del eval_all_unc_confusion_mats
+ del example_viz, images, e_batch, e_predictions, e_logits, logits
return eval_summary
@@ -222,6 +274,7 @@ def evaluate_ood(
writer: metric_writers.MetricWriter,
lead_host: Any,
prefix: str = 'valid',
+ workdir: str ='',
**kwargs,
) -> Dict[str, Any]:
"""Model evaluator.
@@ -250,89 +303,55 @@ def evaluate_ood(
auc_online = kwargs.pop('auc_online', False)
+ # store logits
+ store_logits = config.eval_configs.get('store_logits', False)
+
+ if store_logits:
+ store_logits_fname = os.path.join(workdir, "{}_{}_val.h5py".format(prefix,"logits"))
+ f = h5py.File(store_logits_fname, 'w', libver='latest')
+ f.swmr_mode = True # single write multi-read
+ input_shape = dataset.meta_data['input_shape'][1:3]
+ num_classes = dataset.meta_data['num_classes']
+ num_eval_examples = int(steps_per_eval * config.batch_size)
+ logits_out = f.create_dataset('logits', (num_eval_examples,) + input_shape + (num_classes,))
+ inputs_out = f.create_dataset('inputs', (num_eval_examples,) + input_shape + (3,))
+ labels_out = f.create_dataset('labels', (num_eval_examples,) + input_shape)
+
if auc_online:
# TODO(kellybuchanan): check split of data across devices.
# initialize metrics: ideally in each device in each host/process/machine
- # keras initializes one metric in each host because it runs in cpu
- # so we need to convert to jax to run metrics in each device in each host
-
- auc_pr = tf.keras.metrics.AUC(curve='PR')
- auc_roc = tf.keras.metrics.AUC(curve='ROC')
+ # keras initializes one metric in each host because it runs in cpu.
+ # so we need to convert the function to run metrics in each device/host.
+ auc_pr = ComputeOODAUCMetric(curve='PR', num_thresholds=100)
+ auc_roc = ComputeOODAUCMetric(curve='ROC', num_thresholds=100)
# Loop through each machine:
- for _ in range(steps_per_eval):
+ for step_ in range(steps_per_eval):
eval_batch = next(dataset.valid_iter)
e_batch, e_logits = eval_step_pmapped(
train_state=train_state, batch=eval_batch)
- # In eval_step_pmapped we have not used all gather, so each metric is in
- # each device and we should be able to compute devices separately
-
- ood_score = get_ood_score(e_logits, **kwargs)
-
- auc_roc.update_state(
- e_batch['label'], ood_score, sample_weight=e_batch['batch_mask'])
- auc_pr.update_state(
- e_batch['label'], ood_score, sample_weight=e_batch['batch_mask'])
-
- # How to communicate metrics across hosts?
- # Ideally we can collect auc_metrics per host, merge them, compute result.
- # However, we cannot pass arbitraty class.
- # jax which doesn't work with arbitrary objects
- # Here we write a custom merge_state as in tf.keras.metrics
- # by pulling states from tf.keras obj, combining them and putting them back
- # into a keras object using list of host's auc_roc objects.
-
- def keras_auc_to_arrays(keras_auc_object):
- """Pull out arrays from keras roc object."""
- # The thresholds used are determinisitc, so we need not store them.
- tp = jnp.asarray(keras_auc_object.true_positives)
- fp = jnp.asarray(keras_auc_object.false_positives)
- tn = jnp.asarray(keras_auc_object.true_negatives)
- fn = jnp.asarray(keras_auc_object.false_negatives)
- return tp, fp, tn, fn
-
- def arrays_to_keras_auc(tp, fp, tn, fn, keras_auc_object):
- """Assign confusion matrix arrays to a keras_auc_object."""
- keras_auc_object.true_positives.assign(tp)
- keras_auc_object.false_positives.assign(fp)
- keras_auc_object.true_negatives.assign(tn)
- keras_auc_object.false_negatives.assign(fn)
- return keras_auc_object
-
- auc_roc_state = keras_auc_to_arrays(auc_roc)
- auc_pr_state = keras_auc_to_arrays(auc_pr)
-
- def combine_states(all_auc_states):
- # jax can take in trees of arrays, tuple is considered a tree so we can
- # unpack it here.
- # each array here has dimensions #host x shape
-
- all_tp, all_fp, all_tn, all_fn = all_auc_states
-
- assert all_tp.shape == (jax.process_count(), 200)
- assert all_fp.shape == (jax.process_count(), 200)
- assert all_tn.shape == (jax.process_count(), 200)
- assert all_fn.shape == (jax.process_count(), 200)
-
- tp = jnp.sum(all_tp, 0)
- fp = jnp.sum(all_fp, 0)
- tn = jnp.sum(all_tn, 0)
- fn = jnp.sum(all_fn, 0)
-
- return tp, fp, tn, fn
-
- # Gather the data across all hosts.
- all_auc_roc_states = multihost_utils.process_allgather(auc_roc_state)
- all_auc_pr_states = multihost_utils.process_allgather(auc_pr_state)
-
- # Below we pick the first device.
- auc_roc = arrays_to_keras_auc(*combine_states(all_auc_roc_states), auc_roc)
- auc_pr = arrays_to_keras_auc(*combine_states(all_auc_pr_states), auc_pr)
-
- eval_summary = {'auroc': float(auc_roc.result().numpy()),
- 'auprc': float(auc_pr.result().numpy()),
+ if store_logits:
+ start_idx = step_ * config.batch_size
+ end_idx = start_idx + config.batch_size
+ logits_out[start_idx:end_idx] = e_logits
+ inputs_out[start_idx:end_idx] = e_batch['inputs']
+ labels_out[start_idx:end_idx] = e_batch['labels']
+
+ # In eval_step_pmapped we have not used all gather, so each metric is in each device
+ # and we should be able to compute metrics in devices separately.
+ auc_pr.calculate_and_update_scores(logits=e_logits, label=e_batch['label'],
+ sample_weight=e_batch['batch_mask'], **kwargs)
+ auc_roc.calculate_and_update_scores(logits=e_logits, label=e_batch['label'],
+ sample_weight=e_batch['batch_mask'], **kwargs)
+
+ if store_logits:
+ f.close()
+
+ eval_summary = {'auroc': auc_roc.gather_metrics(),
+ 'auprc': auc_pr.gather_metrics(),
}
+
else:
eval_logits = []
eval_ood_masks = []
@@ -344,7 +363,7 @@ def combine_states(all_auc_states):
e_batch, e_logits = eval_step_pmapped(
train_state=train_state, batch=eval_batch)
- # Store all logits in cpu
+ # Store all logits in cpu:
if lead_host:
e_batch = to_cpu(e_batch, all_gather=False)
e_logits = to_cpu(e_logits, all_gather=False)
@@ -363,6 +382,7 @@ def combine_states(all_auc_states):
ood_mask=eval_ood_labels,
weights=eval_ood_masks,
**kwargs)
+
############### LOG EVAL SUMMARY ###############
writer.write_scalars(
step, {
@@ -498,12 +518,14 @@ def training_loss_fn(params):
logits) # batch_size x h x w x num_classes
metrics = metrics_fn(logits, batch)
- new_train_state = train_state.replace( # pytype: disable=attribute-error
+ logits = jnp.argmax(logits, axis=-1)
+
+ train_state = train_state.replace( # pytype: disable=attribute-error
global_step=step + 1,
optimizer=new_optimizer,
model_state=new_model_state,
rng=new_rng)
- return new_train_state, metrics, lr, jnp.argmax(logits, axis=-1)
+ return train_state, metrics, lr, logits
def eval_step(
@@ -588,10 +610,12 @@ def eval_step(
# Collect predictions and batches from all hosts.
# use all_gather to copy and replicate across all hosts
# we skip doing this for batch and logits to save memory
+ # unless we want to store the logits
# predictions = jnp.argmax(logits, axis=-1)
# predictions = jax.lax.all_gather(predictions, 'batch')
- # logits = jax.lax.all_gather(logits, 'batch')
- # batch = jax.lax.all_gather(batch, 'batch')
+ if config.eval_configs.get('store_logits', False):
+ logits = jax.lax.all_gather(logits, 'batch')
+ batch = jax.lax.all_gather(batch, 'batch')
confusion_matrix = jax.lax.all_gather(confusion_matrix, 'batch')
unc_confusion_matrix = jax.lax.all_gather(unc_confusion_matrix, 'batch')
@@ -656,9 +680,10 @@ def eval_step_baseline(
# Collect predictions and batches from all hosts.
# use all_gather to copy and replicate across all hosts
# we can skip doing this for batch and logits to save memory
- # is the OOM in tpu or cpu?
- # batch = jax.lax.all_gather(batch, 'batch')
- # logits = jax.lax.all_gather(logits, 'batch')
+ # is the OOM in tpu or cpu?
+ if config.eval_configs.get('store_logits', False):
+ batch = jax.lax.all_gather(batch, 'batch')
+ logits = jax.lax.all_gather(logits, 'batch')
return batch, logits
@@ -741,28 +766,7 @@ def train(
# Load pretrained backbone
if start_step == 0 and config.get('load_pretrained_backbone', False):
- # TODO(kellybuchanan): check out partial loader in
- # https://github.com/google/uncertainty-baselines/commit/083b1dcc52bb1964f8917d15552ece8848d582ae#
- restored_model_cfg = config.get('pretrained_backbone_configs')
-
- # Loader from scenic
- if restored_model_cfg.checkpoint_format in ('ub', 'big_vision', 'scenic'):
- # load params from checkpoint
- bb_train_state = pretrain_utils.convert_big_vision_to_scenic_checkpoint(
- checkpoint_path=restored_model_cfg.checkpoint_path,
- convert_to_linen=False)
-
- train_state = model.init_backbone_from_train_state(
- train_state,
- bb_train_state,
- config,
- restored_model_cfg,
- model_prefix_path=['backbone'])
- # Free unnecessary memory.
- del bb_train_state
- else:
- raise NotImplementedError('')
-
+ train_state = load_checkpoints_backbone(config, model, train_state, workdir)
elif start_step == 0:
logging.info('Not restoring from any pretrained_backbone.')
@@ -816,7 +820,7 @@ def train(
checkpoint_steps = config.get('checkpoint_steps') or log_eval_steps
train_metrics, extra_training_logs = [], []
- train_summary, eval_summary = None, None
+ train_summary, eval_summary = {}, {}
global_metrics_fn = model.get_global_metrics_fn() # pytype: disable=attribute-error
global_unc_metrics_fn = model.get_global_unc_metrics_fn() # pytype: disable=attribute-error
@@ -883,12 +887,13 @@ def train(
train_summary = train_utils.log_train_summary(
step=step,
- train_metrics=jax.tree_map(train_utils.unreplicate_and_get,
+ train_metrics=jax.tree_util.tree_map(train_utils.unreplicate_and_get,
train_metrics),
- extra_training_logs=jax.tree_map(train_utils.unreplicate_and_get,
+ extra_training_logs=jax.tree_util.tree_map(train_utils.unreplicate_and_get,
extra_training_logs),
writer=writer)
+ del example_viz, train_metrics, extra_training_logs
# Reset metric accumulation for next evaluation cycle.
train_metrics, extra_training_logs = [], []
@@ -908,6 +913,7 @@ def train(
lead_host=lead_host,
global_metrics_fn=global_metrics_fn,
global_unc_metrics_fn=global_unc_metrics_fn,
+ workdir=workdir,
)
# check accuracy for early stopping.
@@ -1047,26 +1053,7 @@ def eval_ckpt(
checkpoint_configs = config.get('checkpoint_configs', False)
if checkpoint_configs:
- # Load torch weights
- if 'torch' in checkpoint_configs.checkpoint_format:
-
- bb_train_state = convert_torch_to_jax_checkpoint(
- checkpoint_path=checkpoint_configs.checkpoint_path,
- config=checkpoint_configs)
-
- train_state = model.init_backbone_from_train_state(
- train_state,
- bb_train_state,
- config,
- checkpoint_configs
- )
- del bb_train_state
-
- # Load weights in checkpoint_path or workdir
- else:
- checkpoint_path = checkpoint_configs.get('checkpoint_path', workdir)
- train_state, _ = train_utils.restore_checkpoint(
- checkpoint_path, train_state)
+ train_state = load_checkpoints_eval(config, model, train_state, workdir)
else:
logging.info('Not loading any checkpoints')
@@ -1103,6 +1090,7 @@ def eval_ckpt(
global_metrics_fn=global_metrics_fn,
global_unc_metrics_fn=global_unc_metrics_fn,
prefix=prefix,
+ workdir=workdir,
)
# Wait until computations are done before running robustness evaluator.
@@ -1111,6 +1099,7 @@ def eval_ckpt(
# ----------------------------------------------------------------------------
# Evaluate OOD datasets
+ logging.info('Evaluating OOD datasets')
eval_summary_ood = evaluate_ood_step(
train_state=train_state,
config=config,
@@ -1153,7 +1142,7 @@ def evaluate_ood_step(
Returns:
eval_summary: summary evaluation
"""
- del workdir
+ eval_summary = {}
if config.get('eval_covariate_shift', False):
@@ -1168,12 +1157,13 @@ def evaluate_ood_step(
# We can donate the eval_batch's buffer.
)
- eval_summary = None
global_metrics_fn = model.get_global_metrics_fn() # pytype: disable=attribute-error
global_unc_metrics_fn = model.get_global_unc_metrics_fn() # pytype: disable=attribute-error
eval_ood_covariate = {'cityscapes_c': evaluate_cityscapes_c,
- 'ade20k_ind_c': evaluate_ade20k_corrupted,}
+ 'ade20k_ind_c': evaluate_ade20k_corrupted,
+ 'street_hazards_c': evaluate_street_hazards_corrupted,
+ }
# TODO(kellybuchanan): merge data sources.
# The form of the ind dataset name depends on the source of the data.
@@ -1189,22 +1179,28 @@ def evaluate_ood_step(
elif any('ade20k' in ind_name for ind_name in ind_names):
logging.info('Loading Ade20k_ind_c')
ood_dataset = 'ade20k_ind_c'
+ elif any('street' in ind_name for ind_name in ind_names):
+ logging.info('Loading street_hazards_c')
+ ood_dataset = 'street_hazards_c'
else:
logging.info('OOD Covariate shift dataset is not implemented')
+ ood_dataset = None
- eval_summary = eval_ood_covariate[ood_dataset](
- train_state=train_state,
- config=config,
- rng=rng,
- eval_step_pmapped=eval_step_pmapped,
- writer=writer,
- lead_host=lead_host,
- global_metrics_fn=global_metrics_fn,
- global_unc_metrics_fn=global_unc_metrics_fn,
- )
+ if ood_dataset:
+ eval_summary = eval_ood_covariate[ood_dataset](
+ train_state=train_state,
+ config=config,
+ rng=rng,
+ eval_step_pmapped=eval_step_pmapped,
+ writer=writer,
+ lead_host=lead_host,
+ global_metrics_fn=global_metrics_fn,
+ global_unc_metrics_fn=global_unc_metrics_fn,
+ workdir=workdir,
+ )
- # Wait until computations are done before exiting.
- jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
+ # Wait until computations are done before exiting.
+ jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
# ----------------------------------------------------------------------------
if config.get('eval_label_shift', False):
@@ -1221,7 +1217,8 @@ def evaluate_ood_step(
eval_label_shift = {
'fishyscapes': evaluate_fishyscapes,
- 'ade20k_ood_open': evaluate_ade20k_ood_open
+ 'ade20k_ood_open': evaluate_ade20k_ood_open,
+ 'street_hazards_ood_open': evaluate_street_hazards_ood_open,
}
# The form of the ind dataset name depends on the source of the data.
@@ -1234,25 +1231,29 @@ def evaluate_ood_step(
if any('cityscapes' in ind_name for ind_name in ind_names):
logging.info('Loading Fishyscapes...')
ood_dataset = 'fishyscapes'
-
- if any('ade20k' in ind_name for ind_name in ind_names):
+ elif any('ade20k' in ind_name for ind_name in ind_names):
logging.info('Loading ADE20k OOD OPEN...')
ood_dataset = 'ade20k_ood_open'
-
+ elif any('street_hazards' in ind_name for ind_name in ind_names):
+ logging.info('Loading StreetHazards OPEN...')
+ ood_dataset = 'street_hazards_ood_open'
else:
logging.info('OOD Label shift dataset is not implemented')
+ ood_dataset = None
- eval_summary = eval_label_shift[ood_dataset](
- train_state=train_state,
- config=config,
- rng=rng,
- eval_step_pmapped=eval_step_ood_pmapped,
- writer=writer,
- lead_host=lead_host,
- )
+ if ood_dataset:
+ eval_summary = eval_label_shift[ood_dataset](
+ train_state=train_state,
+ config=config,
+ rng=rng,
+ eval_step_pmapped=eval_step_ood_pmapped,
+ writer=writer,
+ lead_host=lead_host,
+ workdir=workdir,
+ )
- # Wait until computations are done before exiting.
- jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
+ # Wait until computations are done before exiting.
+ jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
return eval_summary
@@ -1265,6 +1266,7 @@ def evaluate_cityscapes_c(
lead_host: Any,
global_metrics_fn: Any,
global_unc_metrics_fn: Any,
+ workdir: str = None,
) -> Dict[str, Any]:
"""Evaluate cityscapes-c dataset.
@@ -1291,21 +1293,21 @@ def evaluate_cityscapes_c(
# update config:
ood_config = ml_collections.ConfigDict()
ood_config.update(**config)
- ood_config.update({'dataset_name': 'cityscapes_variants'})
+ ood_config.update({'dataset_name': 'robust_segvit_variants'})
accuracy_per_corruption = {}
- prefix = 'citycvalid'
- for corruption in cityscapes_variants.CITYSCAPES_C_CORRUPTIONS:
+ prefix = 'cityc'
+ for corruption in datasets_info.CITYSCAPES_C_CORRUPTIONS:
local_list = [] # list to compute macro average per corruption
- for severity in cityscapes_variants.CITYSCAPES_C_SEVERITIES:
+ for severity in datasets_info.CITYSCAPES_C_SEVERITIES:
with ood_config.unlocked():
- ood_config.dataset_configs.dataset_name = f'cityscapes_corrupted/semantic_segmentation_{corruption}_{severity}'
+ ood_config.dataset_configs.name = f'cityscapes_c_{corruption}_{severity}'
rng, data_rng = jax.random.split(rng)
dataset = train_utils.get_dataset(ood_config, data_rng)
dataset.meta_data['dataset_name'] = 'cityscapes_c'
- dataset.meta_data['prefix'] = prefix + f'_{corruption}_{severity}'
+ dataset.meta_data['prefix'] = prefix + f'/{corruption}/{severity}/valid'
eval_summary = evaluate(
train_state=train_state,
@@ -1318,6 +1320,7 @@ def evaluate_cityscapes_c(
global_metrics_fn=global_metrics_fn,
global_unc_metrics_fn=global_unc_metrics_fn,
prefix=dataset.meta_data['prefix'],
+ workdir=workdir,
)
local_list.append(eval_summary)
@@ -1331,7 +1334,7 @@ def evaluate_cityscapes_c(
# append name to metrics
key_separator = '_'
avg_cityscapes_c_metrics = {
- key_separator.join((prefix, key)): val
+ key_separator.join((prefix + '/valid', key)): val
for key, val in cityscapes_c_metrics.items()
}
# update metrics
@@ -1348,6 +1351,7 @@ def evaluate_fishyscapes(
eval_step_pmapped: Any,
writer: metric_writers.MetricWriter,
lead_host: Any,
+ workdir: str = '',
) -> Dict[str, Any]:
"""Evaluate Fishyscapes dataset.
@@ -1372,21 +1376,21 @@ def evaluate_fishyscapes(
# update config:
ood_config = ml_collections.ConfigDict()
ood_config.update(**config)
- ood_config.update({'dataset_name': 'cityscapes_variants'})
+ ood_config.update({'dataset_name': 'robust_segvit_variants'})
device_count = jax.device_count()
accuracy_per_corruption = {}
- prefix = 'fishyvalid'
- for corruption in cityscapes_variants.FISHYSCAPES_CORRUPTIONS:
+ prefix = 'fishyscapes'
+ for corruption in datasets_info.FISHYSCAPES_CORRUPTIONS:
with ood_config.unlocked():
- ood_config.dataset_configs.dataset_name = f'fishyscapes/{corruption}'
+ ood_config.dataset_configs.name = f'fishyscapes/{corruption}'
ood_config.batch_size = device_count
data_rng, rng = jax.random.split(rng)
dataset = train_utils.get_dataset(ood_config, data_rng)
dataset.meta_data['dataset_name'] = 'fishyscapes'
- dataset.meta_data['prefix'] = prefix + f'_{corruption}'
+ dataset.meta_data['prefix'] = prefix + f'/{corruption}/valid'
eval_summary = evaluate_ood(
train_state=train_state,
@@ -1397,6 +1401,7 @@ def evaluate_fishyscapes(
writer=writer,
lead_host=lead_host,
prefix=dataset.meta_data['prefix'],
+ workdir=workdir,
**config.get('eval_robustness_configs', {}),
)
@@ -1408,7 +1413,7 @@ def evaluate_fishyscapes(
# append name to metrics
key_separator = '_'
avg_fishyscapes_metrics = {
- key_separator.join((prefix, key)): val
+ key_separator.join((prefix +'/valid', key)): val
for key, val in fishyscapes_metrics.items()
}
# update metrics
@@ -1425,6 +1430,7 @@ def evaluate_ade20k_ood_open(
eval_step_pmapped: Any,
writer: metric_writers.MetricWriter,
lead_host: Any,
+ workdir: str = '',
) -> Dict[str, Any]:
"""Evaluate ADE20k OOD dataset.
@@ -1452,7 +1458,7 @@ def evaluate_ade20k_ood_open(
ood_config.update({'dataset_name': 'robust_segvit_segmentation'})
device_count = jax.device_count()
- prefix = 'ade20k_ood_open'
+ prefix = 'ade20k_ood_open/valid'
with ood_config.unlocked():
ood_config.dataset_configs.name = 'ade20k_ood_open'
@@ -1471,10 +1477,11 @@ def evaluate_ade20k_ood_open(
writer=writer,
lead_host=lead_host,
prefix=dataset.meta_data['prefix'],
+ workdir=workdir,
**config.get('eval_robustness_configs', {}),
)
- # append name to metrics
+ # append name to metrics:
key_separator = '_'
avg_open_set_metrics = {
key_separator.join((prefix, key)): val
@@ -1497,6 +1504,7 @@ def evaluate_ade20k_corrupted(
lead_host: Any,
global_metrics_fn: Any,
global_unc_metrics_fn: Any,
+ workdir : str,
) -> Dict[str, Any]:
"""Evaluate Ade20k-C dataset.
@@ -1530,14 +1538,14 @@ def evaluate_ade20k_corrupted(
prefix = 'ade20k_ind_c'
for corruption in datasets_info.ADE20K_C_CORRUPTIONS:
local_list = [] # list to compute macro average per corruption
- for severity in range(1, 6):
+ for severity in datasets_info.ADE20K_C_SEVERITIES:
with ood_config.unlocked():
ood_config.dataset_configs.name = f'ade20k_ind_c_{corruption}_{severity}'
data_rng, rng = jax.random.split(rng)
dataset = train_utils.get_dataset(ood_config, data_rng)
- dataset.meta_data['prefix'] = prefix + f'_{corruption}_{severity}'
+ dataset.meta_data['prefix'] = prefix + f'/{corruption}/{severity}/valid'
eval_summary = evaluate(
train_state=train_state,
@@ -1550,6 +1558,7 @@ def evaluate_ade20k_corrupted(
global_metrics_fn=global_metrics_fn,
global_unc_metrics_fn=global_unc_metrics_fn,
prefix=dataset.meta_data['prefix'],
+ workdir=workdir,
)
local_list.append(eval_summary)
@@ -1563,7 +1572,94 @@ def evaluate_ade20k_corrupted(
# append name to metrics
key_separator = '_'
avg_corrupted_metrics = {
- key_separator.join((prefix, key)): val
+ key_separator.join((prefix + '/valid', key)): val
+ for key, val in ade20k_c_metrics.items()
+ }
+ # update metrics
+ eval_summary.update(avg_corrupted_metrics)
+ writer.write_scalars(0, avg_corrupted_metrics)
+ writer.flush()
+ return eval_summary
+
+
+def evaluate_street_hazards_corrupted(
+ train_state: train_utils.TrainState,
+ config: ml_collections.ConfigDict,
+ rng: Any,
+ eval_step_pmapped: Any,
+ writer: metric_writers.MetricWriter,
+ lead_host: Any,
+ global_metrics_fn: Any,
+ global_unc_metrics_fn: Any,
+ workdir : str,
+) -> Dict[str, Any]:
+ """Evaluate StreetHazards-C dataset.
+
+ Args:
+ train_state: train state.
+ config: experiment configuration.
+ rng: jax rng.
+ eval_step_pmapped: eval state
+ writer: CLU metrics writer instance.
+ lead_host: Evaluate global metrics on one of the hosts (lead_host) given
+ intermediate values collected from all hosts.
+ global_metrics_fn: global metrics to evaluate.
+ global_unc_metrics_fn: global uncertainty metrics to evaluate.
+ Returns:
+ eval_summary: summary evaluation
+ """
+ # Load dataset
+ # set resource limit to debug in mac osx
+ # (see https://github.com/tensorflow/datasets/issues/1441)
+ if jax.process_index() == 0 and sys.platform == 'darwin':
+ low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
+ resource.setrlimit(resource.RLIMIT_NOFILE, (low, high))
+
+ # update config:
+ ood_config = ml_collections.ConfigDict()
+ ood_config.update(**config)
+ ood_config.update({'dataset_name': 'robust_segvit_variants'})
+
+ # Calculate metrics per corruption.
+ accuracy_per_corruption = {}
+ prefix = 'street_hazards_c'
+ for corruption in datasets_info.STREETHAZARDS_C_CORRUPTIONS:
+ local_list = [] # list to compute macro average per corruption
+ for severity in datasets_info.STREETHAZARDS_C_SEVERITIES:
+
+ with ood_config.unlocked():
+ ood_config.dataset_configs.name = f'street_hazards_c_{corruption}_{severity}'
+
+ data_rng, rng = jax.random.split(rng)
+ dataset = train_utils.get_dataset(ood_config, data_rng)
+ dataset.meta_data['prefix'] = prefix + f'/{corruption}/{severity}/valid'
+
+ eval_summary = evaluate(
+ train_state=train_state,
+ dataset=dataset,
+ config=ood_config,
+ step=0,
+ eval_step_pmapped=eval_step_pmapped,
+ writer=writer,
+ lead_host=lead_host,
+ global_metrics_fn=global_metrics_fn,
+ global_unc_metrics_fn=global_unc_metrics_fn,
+ prefix=dataset.meta_data['prefix'],
+ workdir=workdir,
+ )
+
+ local_list.append(eval_summary)
+
+ accuracy_per_corruption[corruption] = eval_utils.average_list_of_dicts(
+ local_list)
+
+ ade20k_c_metrics = eval_utils.average_list_of_dicts(
+ accuracy_per_corruption.values())
+
+ # append name to metrics
+ key_separator = '_'
+ avg_corrupted_metrics = {
+ key_separator.join((prefix + '/valid', key)): val
for key, val in ade20k_c_metrics.items()
}
# update metrics
@@ -1571,3 +1667,76 @@ def evaluate_ade20k_corrupted(
writer.write_scalars(0, avg_corrupted_metrics)
writer.flush()
return eval_summary
+
+
+def evaluate_street_hazards_ood_open(
+ train_state: train_utils.TrainState,
+ config: ml_collections.ConfigDict,
+ rng: Any,
+ eval_step_pmapped: Any,
+ writer: metric_writers.MetricWriter,
+ lead_host: Any,
+ workdir: str,
+) -> Dict[str, Any]:
+ """Evaluate StreetHazards OOD dataset.
+
+ Args:
+ train_state: train state.
+ config: experiment configuration.
+ rng: jax rng.
+ eval_step_pmapped: eval state
+ writer: CLU metrics writer instance.
+ lead_host: Evaluate global metrics on one of the hosts (lead_host) given
+ intermediate values collected from all hosts.
+
+ Returns:
+ eval_summary: summary evaluation
+ """
+ # set resource limit to debug in mac osx
+ # (see https://github.com/tensorflow/datasets/issues/1441)
+ if jax.process_index() == 0 and sys.platform == 'darwin':
+ low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
+ resource.setrlimit(resource.RLIMIT_NOFILE, (low, high))
+
+ # update config:
+ ood_config = ml_collections.ConfigDict()
+ ood_config.update(**config)
+ ood_config.update({'dataset_name': 'robust_segvit_segmentation'})
+
+ device_count = jax.device_count()
+ prefix = 'street_hazards_open/valid'
+
+ with ood_config.unlocked():
+ ood_config.dataset_configs.name = 'street_hazards_open'
+ ood_config.batch_size = device_count
+
+ data_rng, rng = jax.random.split(rng)
+ dataset = train_utils.get_dataset(ood_config, data_rng)
+ dataset.meta_data['prefix'] = prefix
+
+ eval_summary = evaluate_ood(
+ train_state=train_state,
+ dataset=dataset,
+ config=ood_config,
+ step=0,
+ eval_step_pmapped=eval_step_pmapped,
+ writer=writer,
+ lead_host=lead_host,
+ prefix=dataset.meta_data['prefix'],
+ workdir=workdir,
+ **config.get('eval_robustness_configs', {}),
+ )
+
+ # append name to metrics
+ key_separator = '_'
+ avg_open_set_metrics = {
+ key_separator.join((prefix, key)): val
+ for key, val in eval_summary.items()
+ }
+ # update metrics
+ eval_summary.update(avg_open_set_metrics)
+ writer.write_scalars(0, avg_open_set_metrics)
+ writer.flush()
+
+ return eval_summary
+
diff --git a/experimental/robust_segvit/custom_segmentation_trainer_test.py b/experimental/robust_segvit/custom_segmentation_trainer_test.py
index 90b6c7553..5a6d06d42 100644
--- a/experimental/robust_segvit/custom_segmentation_trainer_test.py
+++ b/experimental/robust_segvit/custom_segmentation_trainer_test.py
@@ -34,7 +34,7 @@
from sklearn import metrics as sk_metrics
import tensorflow as tf
import custom_segmentation_trainer # local file import from experimental.robust_segvit
-
+import custom_models
class SegmentationTrainerTest(parameterized.TestCase):
"""Tests the default trainer on single device setup."""
@@ -226,6 +226,61 @@ def test_get_confusion_matrix(self, seed, masked_fraction):
self.assertAlmostEqual(metrics_dict['mean_iou'], miou_np, places=4)
+ @parameterized.parameters([(0, 0.0), (1, 0.01), (2, 0.5), (3, 0.99), (4, 1)])
+ def test_unc_confusion_matrix(self, seed, masked_fraction):
+ """Test that `pacc_cert` matches the masked pixel accuracy.
+
+ The uncertainty confusion matrix is computed with `uncertainty_th=0.0` and
+ `window_size=1`, and the resulting `pacc_cert` metric is compared against
+ the weighted accuracy computed directly with numpy.
+
+ Args:
+ seed: Seed passed to `np.random.seed`.
+ masked_fraction: Threshold compared against uniform samples to build
+ `batch_mask`.
+ """
+ np.random.seed(seed)
+
+ # Create test data:
+ num_classes = 3
+ input_shape = [8, 1, 224, 224]
+ logits_shape = input_shape + [num_classes]
+ logits_np = np.random.rand(*logits_shape)
+ logits = jnp.array(logits_np)
+
+ # when the uncertainty threshold is 100% or = 0
+ # all labels are certain, and pavpu is the fraction of patches that are accurate.
+ uncertainty_th = 0.0
+ window_size = 1
+
+ # NOTE(review): randint(0, num_classes) yields no -1 labels, so the
+ # `label != -1` term in the mask below is always True for this data.
+ label = np.random.randint(0, num_classes, size=input_shape)
+ label[:4] = np.argmax(logits_np[:4], axis=-1) # Set half to correct.
+
+ batch_np = {
+ 'label':
+ label,
+ 'batch_mask':
+ (np.random.rand(*input_shape) > masked_fraction) & (label != -1),
+ }
+ batch = {
+ 'label': jnp.array(batch_np['label']),
+ 'batch_mask': jnp.array(batch_np['batch_mask']),
+ }
+
+ # Bind the fixed uncertainty settings before pmapping the metric function.
+ cm_pmapped = jax.pmap(
+ functools.partial(
+ custom_segmentation_trainer.get_uncertainty_confusion_matrix,
+ uncertainty_th=uncertainty_th,
+ window_size=window_size,
+ uncertainty_measure='softmax',
+ ), axis_name='batch')
+
+ # One pmapped call per entry along the leading axis of the batch arrays.
+ unc_confusion_matrix = [
+ cm_pmapped(labels=labels, logits=logits_, weights=masks)
+ for labels, logits_, masks in
+ zip(batch['label'], logits, batch['batch_mask'])]
+ unc_confusion_matrix = jax.device_get(jax_utils.unreplicate(unc_confusion_matrix))
+ metrics_dict = custom_models.global_unc_metrics_fn(
+ unc_confusion_matrix)
+ # Numpy reference: weighted fraction of pixels whose argmax equals the label.
+ labels_negative_ignored = np.maximum(batch_np['label'], 0)
+ y_pred = np.argmax(logits_np, axis=-1)
+ weights = batch_np['batch_mask']
+ accurate = labels_negative_ignored == y_pred
+ # np.sum(weights) is 0 when every pixel is masked (masked_fraction=1); the
+ # resulting NaN is mapped to 0 by jnp.nan_to_num in the assertion below.
+ pavpu = np.sum(accurate * weights) / np.sum(weights)
+ # all labels are certain, pavpu = fraction of patches that are accurate
+ self.assertAlmostEqual(metrics_dict['pacc_cert'], jnp.nan_to_num(pavpu), places=2)
+
if __name__ == '__main__':
absltest.main()
diff --git a/experimental/robust_segvit/metrics_multihost.py b/experimental/robust_segvit/metrics_multihost.py
new file mode 100644
index 000000000..ae082fb60
--- /dev/null
+++ b/experimental/robust_segvit/metrics_multihost.py
@@ -0,0 +1,133 @@
+"""Calculate ood metrics across hosts.
+
+# How to communicate metrics across hosts?
+# Ideally we can collect auc_metrics per host, merge them, compute result.
+# However, we cannot pass an arbitrary class between hosts:
+# jax does not work with arbitrary objects.
+
+
+"""
+from typing import Any, Optional, Dict
+
+import jax
+import jax.numpy as jnp
+import tensorflow as tf
+from jax.experimental import multihost_utils
+from ood_metrics import get_ood_score
+from ood_metrics import get_score
+import numpy as np
+import copy
+
+
+# Here we write a custom merge_state as in tf.keras.metrics
+# by pulling states from tf.keras obj, combining them and putting them back
+# into a keras object using list of host's auc_roc objects.
+def keras_auc_to_arrays(keras_auc_object):
+ """Pull out arrays from keras roc object."""
+ # The thresholds used are determinisitc, so we need not store them.
+ tp = jnp.asarray(keras_auc_object.true_positives)
+ fp = jnp.asarray(keras_auc_object.false_positives)
+ tn = jnp.asarray(keras_auc_object.true_negatives)
+ fn = jnp.asarray(keras_auc_object.false_negatives)
+ return tp, fp, tn, fn
+
+
+def arrays_to_keras_auc(tp, fp, tn, fn, keras_auc_object):
+ """Assign confusion matrix arrays to a keras_auc_object."""
+ keras_auc_object.true_positives.assign(tp)
+ keras_auc_object.false_positives.assign(fp)
+ keras_auc_object.true_negatives.assign(tn)
+ keras_auc_object.false_negatives.assign(fn)
+ return keras_auc_object
+
+
+def combine_states(all_auc_states, num_thresholds=200):
+ # jax can take in trees of arrays, tuple is considered a tree so we can
+ # unpack it here.
+ # each array here has dimensions #host x shape
+
+ all_tp, all_fp, all_tn, all_fn = all_auc_states
+
+ assert all_tp.shape == (jax.process_count(), num_thresholds)
+ assert all_fp.shape == (jax.process_count(), num_thresholds)
+ assert all_tn.shape == (jax.process_count(), num_thresholds)
+ assert all_fn.shape == (jax.process_count(), num_thresholds)
+
+ tp = jnp.sum(all_tp, 0)
+ fp = jnp.sum(all_fp, 0)
+ tn = jnp.sum(all_tn, 0)
+ fn = jnp.sum(all_fn, 0)
+
+ return tp, fp, tn, fn
+
+
+def host_all_gather_metrics(metric):
+ states = multihost_utils.process_allgather(metric.get_weights())
+ state = jax.tree_util.tree_map(lambda x: np.sum(x, axis=0), states)
+ metric_copy = copy.deepcopy(metric)
+ metric_copy.set_weights(state)
+ return metric_copy
+
+
+class ComputeAUCMetric:
+ """Calculate auc metrics across multiple hosts."""
+ def __init__(self, curve, num_thresholds=200, from_logits=False):
+ self.curve = curve
+ self.num_thresholds = num_thresholds
+ self.from_logits = from_logits
+ self.auc = tf.keras.metrics.AUC(curve=self.curve,
+ from_logits=self.from_logits,
+ num_thresholds=self.num_thresholds)
+
+ def calculate_and_update_scores(self, logits, label, sample_weight):
+ self.auc.update_state(label, logits, sample_weight=sample_weight)
+
+ def gather_metrics(self):
+ auc_state = keras_auc_to_arrays(self.auc)
+
+ # Gather the data across all hosts.
+ all_auc_states = multihost_utils.process_allgather(auc_state)
+
+ # Below we pick the first device.
+ self.auc = arrays_to_keras_auc(*combine_states(all_auc_states,
+ num_thresholds=self.num_thresholds),
+ self.auc)
+
+ return self.auc.result()
+
+
+class ComputeOODAUCMetric:
+ """Calculate auc metrics across multiple hosts.
+
+ Args:
+ curve: 'ROC' or 'PR' for the type of AUC.
+ num_thresholds: Number of thresholds to use for discretizing the roc curve.
+ from_logits: Whether `y_pred` is expected to be a logits tensor. If it is a logits tensor,
+ a sigmoid function is applied to the logits.
+
+ """
+ def __init__(self, curve, num_thresholds=200):
+ self.curve = curve
+ self.num_thresholds = num_thresholds
+ self.from_logits = False
+ self.auc = tf.keras.metrics.AUC(curve=self.curve,
+ from_logits=self.from_logits,
+ num_thresholds=self.num_thresholds)
+
+ def calculate_and_update_scores(self, logits, label, sample_weight, **kwargs):
+ ood_score = get_ood_score(logits, **kwargs)
+
+ # skip images where all the pixels are ood or there are no ood pixels
+ all_pixel_ood = jnp.sum(label*sample_weight) == 1
+ no_pixel_ood = jnp.sum(label*sample_weight) == 0
+
+ if not(all_pixel_ood) and not(no_pixel_ood):
+ self.auc.update_state(label, ood_score, sample_weight=sample_weight)
+
+ def gather_metrics(self):
+
+ # Gather the metrics:
+ self.auc = host_all_gather_metrics(self.auc)
+
+ return self.auc.result()
+
diff --git a/experimental/robust_segvit/metrics_multihost_test.py b/experimental/robust_segvit/metrics_multihost_test.py
new file mode 100644
index 000000000..fb46d62fd
--- /dev/null
+++ b/experimental/robust_segvit/metrics_multihost_test.py
@@ -0,0 +1,142 @@
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax import jax_utils
+
+from metrics_multihost import ComputeAUCMetric
+from metrics_multihost import ComputeOODAUCMetric
+import sklearn.metrics
+
+from ood_metrics import get_ood_score
+
+class OODMetricsMultiHostTest(parameterized.TestCase):
+ """Tests ComputeAUCMetric and ComputeOODAUCMetric against sklearn ROC AUC."""
+
+ def setUp(self):
+ super(OODMetricsMultiHostTest, self).setUp()
+
+ @parameterized.parameters([(0, 0.0), (1, 0.01), (2, 0.5), (3, 0.99), (4, 1)])
+ def test_ComputeAUCMetric(self, seed, masked_fraction):
+ """Test computation of AUC metric."""
+ np.random.seed(seed)
+
+ from_logits = False # when set to True applies sigmoid to logits.
+ num_thresholds = 100
+
+ # Create test data:
+ num_classes = 2
+ input_shape = [8, 1, 224, 224]
+ logits_shape = input_shape + [num_classes]
+ logits_np = np.random.rand(*logits_shape)
+
+ # NOTE(review): randint(0, num_classes) yields no -1 labels, so the
+ # `label != -1` term in the mask below is always True for this data.
+ label = np.random.randint(0, num_classes, size=input_shape)
+ label[:4] = np.argmax(logits_np[:4], axis=-1) # Set half to correct.
+
+ batch_np = {
+ 'logits': logits_np,
+ 'label':
+ label,
+ 'batch_mask':
+ (np.random.rand(*input_shape) > masked_fraction) & (label != -1),
+ }
+ batch = {
+ 'logits': jnp.array(logits_np),
+ 'label': jnp.array(batch_np['label']),
+ 'batch_mask': jnp.array(batch_np['batch_mask']),
+ }
+
+ fake_batches_replicated = jax_utils.replicate([batch])
+
+ auc_roc = ComputeAUCMetric(curve='ROC', num_thresholds=num_thresholds, from_logits=from_logits)
+
+ for fake_batch in fake_batches_replicated:
+ if from_logits:
+ pred = jnp.max(fake_batch['logits'], axis=-1)
+ else:
+ pred = jnp.argmax(fake_batch['logits'], axis=-1)
+ auc_roc.calculate_and_update_scores(logits=pred,
+ label=fake_batch['label'],
+ sample_weight=fake_batch['batch_mask'],
+ )
+
+ auc_result = auc_roc.gather_metrics().numpy()
+
+ # Numpy result:
+ # The expected value is 0 when every element is masked out; otherwise it is
+ # the sklearn ROC AUC of the argmax predictions on the flattened arrays.
+ if np.all(batch_np['batch_mask'] == 0):
+ auc_numpy = 0
+ else:
+ labels_negative_ignored = np.maximum(batch_np['label'], 0)
+ y_pred = np.argmax(logits_np, axis=-1)
+ auc_numpy = sklearn.metrics.roc_auc_score(labels_negative_ignored.ravel(),
+ y_pred.ravel(),
+ sample_weight=batch_np['batch_mask'].ravel())
+
+ self.assertAlmostEqual(auc_result, auc_numpy, places=2)
+
+
+ @parameterized.parameters([(0, 0.0), (1, 0.01), (2, 0.5), (3, 0.99), (4, 1)])
+ def test_ComputeOODAUCMetric(self, seed, masked_fraction):
+ """Test computation of OOD scored AUC metric."""
+ np.random.seed(seed)
+ num_thresholds = 1000
+
+ # Keyword arguments forwarded to get_ood_score by the metric and by the
+ # numpy reference below.
+ ood_kwargs = {
+ 'method_name': 'mlogit',
+ }
+ # Create test data:
+ num_classes = 2
+ input_shape = [8, 1, 224, 224]
+ logits_shape = input_shape + [num_classes]
+ logits_np = np.random.rand(*logits_shape)
+
+ # NOTE(review): randint(0, num_classes) yields no -1 labels, so the
+ # `label != -1` term in the mask below is always True for this data.
+ label = np.random.randint(0, num_classes, size=input_shape)
+ label[:4] = np.argmax(logits_np[:4], axis=-1) # Set half to correct.
+
+ batch_np = {
+ 'logits': logits_np,
+ 'label':
+ label,
+ 'batch_mask':
+ (np.random.rand(*input_shape) > masked_fraction) & (label != -1),
+ }
+ batch = {
+ 'logits': jnp.array(logits_np),
+ 'label': jnp.array(batch_np['label']),
+ 'batch_mask': jnp.array(batch_np['batch_mask']),
+ }
+
+ fake_batches_replicated = jax_utils.replicate([batch])
+
+ auc_roc = ComputeOODAUCMetric(curve='ROC', num_thresholds=num_thresholds)
+
+ for fake_batch in fake_batches_replicated:
+ pred = fake_batch['logits']
+ # With two classes, class 0 is used as the positive (ood) label.
+ ood_label = 1 - fake_batch['label']
+
+ auc_roc.calculate_and_update_scores(logits=pred,
+ label=ood_label,
+ sample_weight=fake_batch['batch_mask'],
+ **ood_kwargs,
+ )
+ auc_result = auc_roc.gather_metrics().numpy()
+
+ # Numpy result:
+ # The expected value is 0 when every element is masked out; otherwise it is
+ # the sklearn ROC AUC of the ood scores on the flattened arrays.
+ if np.all(batch_np['batch_mask'] == 0):
+ auc_numpy = 0
+ else:
+ labels_negative_ignored = np.maximum(batch_np['label'], 0)
+ ood_label_np = 1 - labels_negative_ignored
+ ood_score = get_ood_score(logits_np, **ood_kwargs)
+ auc_numpy = sklearn.metrics.roc_auc_score(ood_label_np.ravel(),
+ ood_score.ravel(),
+ sample_weight=batch_np['batch_mask'].ravel())
+
+ # Only one decimal place is checked here.
+ self.assertAlmostEqual(auc_result, auc_numpy, places=1)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/experimental/robust_segvit/ood_metrics.py b/experimental/robust_segvit/ood_metrics.py
index 36e3b2ce5..2f273b02c 100644
--- a/experimental/robust_segvit/ood_metrics.py
+++ b/experimental/robust_segvit/ood_metrics.py
@@ -81,7 +81,7 @@ def preprocess_outlier(outlier):
def get_ood_score(
logits: jnp.ndarray,
- method_name: str = 'msp',
+ method_name: str = 'nmlogit',
num_top_k: int = 5,
) -> Dict[str, Any]:
"""Get OOD score."""
@@ -97,6 +97,9 @@ def get_ood_score(
elif method_name == 'mlogit':
max_logits = jnp.max(logits, -1)
ood_score = 1 - max_logits
+ elif method_name == 'nmlogit':
+ max_logits = jnp.max(logits, -1)
+ ood_score = - 1 * max_logits
elif method_name == 'sum_topklogit':
ood_score = jax.lax.top_k(logits, num_top_k)[0].sum(-1)
elif method_name == '1-sum_topklogit':
@@ -108,6 +111,31 @@ def get_ood_score(
return ood_score
+def get_score(
+ logits: jnp.ndarray,
+ method_name: str = 'mlogit',
+ num_top_k: int = 5,
+ ) -> Dict[str, Any]:
+ """Get OOD score."""
+
+ if method_name == 'msp':
+ probs = jax.nn.softmax(logits, -1)
+ ood_score = jnp.max(probs, -1)
+ elif method_name == 'entropy':
+ probs = jax.nn.softmax(logits, -1)
+ entropy = -jnp.sum(probs * jnp.log(probs), axis=-1)
+ ood_score = entropy
+ elif method_name == 'mlogit':
+ ood_score = jnp.max(logits, -1)
+ elif method_name == 'sum_topklogit':
+ ood_score = jax.lax.top_k(logits, num_top_k)[0].sum(-1)
+ else:
+ raise NotImplementedError(
+ f'Missing method {method_name} to calculate OOD score.')
+ return ood_score
+
+
+
def get_ood_metrics(
logits: jnp.ndarray,
ood_mask: jnp.ndarray,
@@ -148,11 +176,9 @@ def get_ood_metrics(
# the weights per entry are 1 if it should be included during computation
# and 0 otherwise.
- # the masked array makes any entry with value 1 as invalid.
- y_true_masked = np.ma.masked_array(y_true, mask=1-weights)
- ood_score_masked = np.ma.masked_array(ood_score, mask=1-weights)
+ y_true_masked = y_true[weights == 1]
+ ood_score_masked = ood_score[weights == 1]
- metrics = compute_ood_metrics(y_true_masked.flatten(),
- ood_score_masked.flatten())
+ metrics = compute_ood_metrics(y_true_masked.flatten(), ood_score_masked.flatten())
return metrics
diff --git a/experimental/robust_segvit/pretrainer_utils.py b/experimental/robust_segvit/pretrainer_utils.py
index a5c4ffbcc..a3b1d4d76 100644
--- a/experimental/robust_segvit/pretrainer_utils.py
+++ b/experimental/robust_segvit/pretrainer_utils.py
@@ -22,6 +22,7 @@
import numpy as np
from scenic.train_lib_deprecated import train_utils
from tensorflow.io import gfile
+from flax.training import checkpoints
def load_bb_config(
@@ -197,6 +198,59 @@ def convert_torch_to_jax_checkpoint(
optimizer={"target": restored_params},)
# pytype: enable=wrong-arg-types
+ return restored_train_state
+
+
+def convert_vision_transformer_to_scenic(
+ checkpoint_path: str,
+ convert_to_linen: bool = True) -> train_utils.TrainState:
+ """Converts a vision_transformer checkpoint to a Scenic train state.
+
+ The model weights come from https://github.com/google-research/vision_transformer.
+
+ Original code: convert_big_vision_to_scenic_checkpoint
+ from https://github.com/google-research/scenic/
+
+ Args:
+ checkpoint_path: Path to checkpoint.
+ convert_to_linen: Whether to convert to Linen format.
+
+ Returns:
+ restored_train_state: Scenic train state whose `optimizer['target']` holds
+ the restored parameters.
+ """
+
+ # Rebuilds a nested dict from flat keys joined by `separator`. A non-zero
+ # `leaf_idx` truncates each split key with `[:leaf_idx]` before nesting;
+ # `leaf_idx=0` keeps the full path.
+ def unflatten_dict(flattened: Dict[str, Any],
+ separator: str = '/',
+ leaf_idx: int = -1) -> Dict[str, Any]:
+ unflattened = {}
+ for k, v in flattened.items():
+ subtree = unflattened
+ if leaf_idx != 0:
+ path = k.split(separator)[:leaf_idx]
+ else:
+ path = k.split(separator)
+ for k2 in path[:-1]:
+ if k2 not in subtree:
+ subtree[k2] = {}
+ subtree = subtree[k2]
+ subtree[path[-1]] = v
+ return unflattened
+
+ logging.info('Loading vision_transformer checkpoint from %s', checkpoint_path)
+ # NOTE(review): the GFile handle passed to np.load is never explicitly closed.
+ checkpoint_data = np.load(gfile.GFile(checkpoint_path, 'rb'))
+ restored_params = unflatten_dict(checkpoint_data, separator='/', leaf_idx=0)
+
+ if convert_to_linen:
+ restored_params = checkpoints.convert_pre_linen(restored_params)
+ restored_params = dict(restored_params)
+
+ # Only the optimizer target is filled in on the fresh train state.
+ train_state = train_utils.TrainState()
+ # pytype: disable=wrong-arg-types
+ restored_train_state = train_state.replace( # pytype: disable=attribute-error
+ optimizer={"target": restored_params},)
+ # pytype: enable=wrong-arg-types
+
# free memory
del restored_params
- return restored_train_state
+ return restored_train_state
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_ade20k_ind_be_eval.yaml b/experimental/robust_segvit/run_ade20k_ind_be_eval.yaml
new file mode 100755
index 000000000..90863139c
--- /dev/null
+++ b/experimental/robust_segvit/run_ade20k_ind_be_eval.yaml
@@ -0,0 +1,39 @@
+name: ade20k_ind_be_eval
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 16
+ config.eval_configs.store_logits:
+ value: false
+ config.eval_covariate_shift:
+ value: false
+ config.eval_robustness_configs.method_name:
+ value: 'msp'
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/ade20k_ind/be_eval.py"
+ - "--output_dir"
+ - "gs://ub-ekb/segmenter/ade20k_ind/be_eval"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_ade20k_ind_deterministic_eval.yaml b/experimental/robust_segvit/run_ade20k_ind_deterministic_eval.yaml
new file mode 100755
index 000000000..01347a6a9
--- /dev/null
+++ b/experimental/robust_segvit/run_ade20k_ind_deterministic_eval.yaml
@@ -0,0 +1,39 @@
+name: ade20k_ind_deterministic_eval
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 16
+ config.eval_configs.store_logits:
+ value: false
+ config.eval_covariate_shift:
+ value: true
+ config.eval_robustness_configs.method_name:
+ value: 'msp'
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/ade20k_ind/deterministic_eval.py"
+ - "--output_dir"
+ - "gs://ub-ekb/segmenter/ade20k_ind/deterministic_eval"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_ade20k_ind_gp_eval.yaml b/experimental/robust_segvit/run_ade20k_ind_gp_eval.yaml
new file mode 100755
index 000000000..0612dcdf1
--- /dev/null
+++ b/experimental/robust_segvit/run_ade20k_ind_gp_eval.yaml
@@ -0,0 +1,40 @@
+name: ade20k_ind_gp_eval
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 16
+ config.eval_configs.store_logits:
+ value: false
+ config.eval_covariate_shift:
+ value: false
+ config.eval_robustness_configs.method_name:
+ value: 'msp'
+
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/ade20k_ind/gp_eval.py"
+ - "--output_dir"
+ - "gs://ub-ekb/segmenter/ade20k_ind/gp_eval"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_ade20k_ind_het_eval.yaml b/experimental/robust_segvit/run_ade20k_ind_het_eval.yaml
new file mode 100755
index 000000000..0ed843f4b
--- /dev/null
+++ b/experimental/robust_segvit/run_ade20k_ind_het_eval.yaml
@@ -0,0 +1,39 @@
+name: ade20k_ind_het_eval
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 16
+ config.eval_configs.store_logits:
+ value: false
+ config.eval_covariate_shift:
+ value: false
+ config.eval_robustness_configs.method_name:
+ value: 'msp'
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/ade20k_ind/het_eval.py"
+ - "--output_dir"
+ - "gs://ub-ekb/segmenter/ade20k_ind/het_eval"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_cityscapes_be_eval.yaml b/experimental/robust_segvit/run_cityscapes_be_eval.yaml
new file mode 100755
index 000000000..b0afa8ebc
--- /dev/null
+++ b/experimental/robust_segvit/run_cityscapes_be_eval.yaml
@@ -0,0 +1,40 @@
+name: cityscapes_be_eval
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 16
+ config.eval_configs.store_logits:
+ value: false
+ config.eval_covariate_shift:
+ value: false
+ config.eval_robustness_configs.method_name:
+ value: 'msp'
+
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/cityscapes/be_eval.py"
+ - "--output_dir"
+ - "gs://ub-ekb/segmenter/cityscapes/be_eval"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_cityscapes_deterministic_eval.yaml b/experimental/robust_segvit/run_cityscapes_deterministic_eval.yaml
new file mode 100755
index 000000000..b2eced3cc
--- /dev/null
+++ b/experimental/robust_segvit/run_cityscapes_deterministic_eval.yaml
@@ -0,0 +1,39 @@
+name: cityscapes_deterministic_eval
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 16
+ config.eval_configs.store_logits:
+ value: false
+ config.eval_covariate_shift:
+ value: false
+ config.eval_robustness_configs.method_name:
+ value: 'msp'
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/cityscapes/deterministic_eval.py"
+ - "--output_dir"
+ - "gs://ub-ekb/segmenter/cityscapes/deterministic_eval"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_cityscapes_gp_eval.yaml b/experimental/robust_segvit/run_cityscapes_gp_eval.yaml
new file mode 100755
index 000000000..ff44994d1
--- /dev/null
+++ b/experimental/robust_segvit/run_cityscapes_gp_eval.yaml
@@ -0,0 +1,40 @@
+name: cityscapes_gp_eval
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 16
+ config.eval_configs.store_logits:
+ value: false
+ config.eval_covariate_shift:
+ value: false
+ config.eval_robustness_configs.method_name:
+ value: 'msp'
+
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/cityscapes/gp_eval.py"
+ - "--output_dir"
+ - "gs://ub-ekb/segmenter/cityscapes/gp_eval"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_cityscapes_het_eval.yaml b/experimental/robust_segvit/run_cityscapes_het_eval.yaml
new file mode 100755
index 000000000..671486f60
--- /dev/null
+++ b/experimental/robust_segvit/run_cityscapes_het_eval.yaml
@@ -0,0 +1,40 @@
+name: cityscapes_het_eval
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 16
+ config.eval_configs.store_logits:
+ value: false
+ config.eval_covariate_shift:
+ value: false
+ config.eval_robustness_configs.method_name:
+ value: 'msp'
+
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/cityscapes/het_eval.py"
+ - "--output_dir"
+ - "gs://ub-ekb/segmenter/cityscapes/het_eval"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_deterministic_cityscapes.yaml b/experimental/robust_segvit/run_deterministic_cityscapes.yaml
new file mode 100755
index 000000000..a9bb2eabd
--- /dev/null
+++ b/experimental/robust_segvit/run_deterministic_cityscapes.yaml
@@ -0,0 +1,34 @@
+name: deterministic_cityscapes
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 16
+
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/cityscapes/deterministic.py"
+ - "--output_dir"
+ - "gs://ub-ekb/segmenter/cityscapes/deterministic"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_deterministic_local.sh b/experimental/robust_segvit/run_deterministic_local.sh
new file mode 100755
index 000000000..19c7a2d6a
--- /dev/null
+++ b/experimental/robust_segvit/run_deterministic_local.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+
+# train toy model using wandb
+#wandb sweep run_toy_mac.yaml
+# before make sure we can run code vanilla version:
+
+DATASET='ade20k_ind' # or cityscapes
+DATASET='street_hazards'
+DATASET='cityscapes'
+
+base_output_dir="gs://ub-ekb/segmenter/${DATASET}/deterministic"
+
+# Debug on Mac OS X platform
+use_gpu=False
+
+if [ "$(uname)" = "Darwin" ] ; then
+tpu=False
+num_cores=1
+batch_size=5
+elif [ "$(uname)" = "Linux" ]; then
+tpu='local'
+num_cores=8
+batch_size=8
+fi
+
+use_wandb=True
+
+config_file="configs/${DATASET}/deterministic.py:runlocal"
+run_name="local"
+output_dir="${base_output_dir}/${run_name}"
+python deterministic.py \
+--output_dir=${output_dir} \
+--num_cores=$num_cores \
+--use_gpu=$use_gpu \
+--config=${config_file} \
+--config.batch_size=${batch_size} \
+--config.use_wandb=${use_wandb} \
+--tpu=${tpu} \
diff --git a/experimental/robust_segvit/run_deterministic_street_hazards.yaml b/experimental/robust_segvit/run_deterministic_street_hazards.yaml
new file mode 100755
index 000000000..182f288e0
--- /dev/null
+++ b/experimental/robust_segvit/run_deterministic_street_hazards.yaml
@@ -0,0 +1,38 @@
+name: deterministic_street_hazards_hparam
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 16
+ config.lr_configs.base_learning_rate:
+ values: [0.0001, 0.00001, 0.0003]
+ config.num_training_epochs:
+ values: [50, 100]
+
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/street_hazards/deterministic.py"
+ - "--output_dir"
+ - "gs://ub-ekb/segmenter/street_hazards/deterministic"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_eval.sh b/experimental/robust_segvit/run_eval.sh
new file mode 100644
index 000000000..f2f120687
--- /dev/null
+++ b/experimental/robust_segvit/run_eval.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+# call eval model using wandb
+
+# Debug on Mac OS X platform
+use_gpu=False
+if [ "$(uname)" = "Darwin" ] ; then
+tpu=False
+num_cores=1
+batch_size=5
+elif [ "$(uname)" = "Linux" ]; then
+tpu='local'
+num_cores=8
+batch_size=8
+fi
+
+# default config for eval
+eval_covariate_shift=False
+method_name='msp'
+use_wandb=True
+
+for dataset in "ade20k_ind" "street_hazards" "cityscapes"
+do
+for model in "gp" "be" "deterministic" "het"
+do
+base_output_dir="gs://ub-ekb/segmenter/${dataset}"
+config_file="configs/${dataset}/${model}_eval.py"
+run_name="${model}_eval"
+output_dir="${base_output_dir}/${run_name}"
+python deterministic.py \
+--output_dir=${output_dir} \
+--num_cores=$num_cores \
+--use_gpu=$use_gpu \
+--config=${config_file} \
+--config.batch_size=${batch_size} \
+--config.eval_robustness_configs.method_name=${method_name} \
+--config.eval_covariate_shift=${eval_covariate_shift} \
+--config.use_wandb=${use_wandb} \
+--tpu=${tpu} \
+
+done
+done
diff --git a/experimental/robust_segvit/run_eval_local.sh b/experimental/robust_segvit/run_eval_local.sh
new file mode 100755
index 000000000..16fed5f45
--- /dev/null
+++ b/experimental/robust_segvit/run_eval_local.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+
+# evaluate model using wandb
+#wandb sweep run_toy_mac.yaml
+# before make sure we can run code vanilla version:
+
+DATASET='ade20k_ind' # or cityscapes
+DATASET='street_hazards'
+
+# Parameters
+DATASET='cityscapes'
+model='deterministic'
+
+base_output_dir="gs://ub-ekb/segmenter/${DATASET}/${model}_eval"
+
+# Debug on Mac OS X platform
+use_gpu=False
+if [ "$(uname)" = "Darwin" ] ; then
+tpu=False
+num_cores=1
+batch_size=1
+elif [ "$(uname)" = "Linux" ]; then
+tpu='local'
+num_cores=8
+batch_size=8
+fi
+
+config_file="configs/${DATASET}/${model}_eval.py:runlocal"
+run_name="${model}_eval"
+output_dir="${base_output_dir}/${run_name}"
+python deterministic.py \
+--output_dir=${output_dir} \
+--num_cores=$num_cores \
+--use_gpu=$use_gpu \
+--config=${config_file} \
+--config.batch_size=${batch_size} \
+--tpu=${tpu} \
diff --git a/experimental/robust_segvit/run_street_hazards_be.yaml b/experimental/robust_segvit/run_street_hazards_be.yaml
new file mode 100755
index 000000000..98742c2c9
--- /dev/null
+++ b/experimental/robust_segvit/run_street_hazards_be.yaml
@@ -0,0 +1,37 @@
+name: be_street_hazards_hparam
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 24
+ config.model.backbone.random_sign_init:
+ values: [-0.5, -0.25, 0.25, 0.5]
+ config.fast_weight_lr_multiplier:
+ values: [0.5, 1.0, 2.0]
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/street_hazards/be.py"
+ - "--output_dir"
+ - "gs://ub-ekb/segmenter/street_hazards/be"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_street_hazards_deterministic_eval.yaml b/experimental/robust_segvit/run_street_hazards_deterministic_eval.yaml
new file mode 100755
index 000000000..78e7d7ee0
--- /dev/null
+++ b/experimental/robust_segvit/run_street_hazards_deterministic_eval.yaml
@@ -0,0 +1,39 @@
+name: street_hazards_deterministic_eval
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 8
+ config.eval_configs.store_logits:
+ value: false
+ config.eval_covariate_shift:
+ value: false
+ config.eval_robustness_configs.method_name:
+ value: 'msp'
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/street_hazards/deterministic_eval.py"
+ - "--output_dir"
+ - "gs://ub-ekb/segmenter/street_hazards/deterministic_eval"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_street_hazards_gp.yaml b/experimental/robust_segvit/run_street_hazards_gp.yaml
new file mode 100755
index 000000000..f6bbdf6b9
--- /dev/null
+++ b/experimental/robust_segvit/run_street_hazards_gp.yaml
@@ -0,0 +1,35 @@
+name: gp_street_hazards_hparam
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 24
+ config.model.decoder.mean_field_factor:
+ values: [1, 2, 5, 6, 10]
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/street_hazards/gp.py"
+ - "--output_dir"
+ - "gs://ub-ekb/segmenter/street_hazards/gp"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_street_hazards_gp_eval.yaml b/experimental/robust_segvit/run_street_hazards_gp_eval.yaml
new file mode 100755
index 000000000..988f79c2b
--- /dev/null
+++ b/experimental/robust_segvit/run_street_hazards_gp_eval.yaml
@@ -0,0 +1,39 @@
+name: street_hazards_gp_eval
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 8
+ config.eval_configs.store_logits:
+ value: false
+ config.eval_covariate_shift:
+ value: false
+ config.eval_robustness_configs.method_name:
+ value: 'msp'
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/street_hazards/gp_eval.py"
+ - "--output_dir"
+ - "gs://ub-ekb/segmenter/street_hazards/gp_eval"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_street_hazards_het.yaml b/experimental/robust_segvit/run_street_hazards_het.yaml
new file mode 100755
index 000000000..af4fd2799
--- /dev/null
+++ b/experimental/robust_segvit/run_street_hazards_het.yaml
@@ -0,0 +1,35 @@
+name: het_street_hazards_hparam
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 24
+ config.model.decoder.temperature:
+ values: [0.15, 0.3, 1, 1.5, 2.0]
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/street_hazards/het.py"
+ - "--output_dir"
+ - "gs://ub-ekb/segmenter/street_hazards/het"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_toy_eval.sh b/experimental/robust_segvit/run_toy_eval.sh
new file mode 100755
index 000000000..67426937b
--- /dev/null
+++ b/experimental/robust_segvit/run_toy_eval.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+# Evaluate the toy model with a plain (non-sweep) run. Check that this works
+# before launching a wandb sweep, e.g. `wandb sweep run_toy_mac.yaml`.
+
+# Dataset to evaluate; other options: 'ade20k_ind', 'street_hazards'.
+DATASET='cityscapes'
+
+base_output_dir="gs://ub-ekb/segmenter/${DATASET}/toy_model"
+
+# Device settings: Mac OS X for debugging, local TPU on Linux.
+use_gpu=False
+if [ "$(uname)" = "Darwin" ]; then
+  tpu=False; num_cores=1; batch_size=5
+elif [ "$(uname)" = "Linux" ]; then
+  tpu='local'; num_cores=8; batch_size=8
+else
+  # Fail early instead of passing empty flag values to the trainer.
+  echo "Unsupported platform: $(uname)" >&2
+  exit 1
+fi
+
+use_wandb=True
+config_file="configs/${DATASET}/toy_model_eval.py:runlocal"
+run_name="toy_model_eval"
+output_dir="${base_output_dir}/${run_name}"
+
+# The last flag has no trailing backslash, so the command ends there.
+python deterministic.py \
+  --output_dir="${output_dir}" \
+  --num_cores="${num_cores}" \
+  --use_gpu="${use_gpu}" \
+  --config="${config_file}" \
+  --config.batch_size="${batch_size}" \
+  --config.use_wandb="${use_wandb}" \
+  --tpu="${tpu}"
diff --git a/experimental/robust_segvit/run_toy_mac.sh b/experimental/robust_segvit/run_toy_mac.sh
new file mode 100755
index 000000000..adc4c613c
--- /dev/null
+++ b/experimental/robust_segvit/run_toy_mac.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+
+# ----------------------------------------------------
+# Train a toy model on DATASET.
+# ----------------------------------------------------
+
+# To train the toy model and track its performance with wandb instead:
+#   wandb sweep run_toy_mac.yaml
+
+# Other available datasets: 'ade20k_ind', 'street_hazards'.
+DATASET='cityscapes'
+
+# ----------------------------------------------------
+# Set the directory where outputs are written.
+# ----------------------------------------------------
+# Results can also be written directly to a GCP bucket:
+# base_output_dir="gs://ub-ekb/segmenter/${DATASET}/toy_model"
+base_output_dir="results/${DATASET}"
+
+# The timestamp is part of the output path of each run.
+dt=$(date +"%Y-%m-%d-%H-%M-%S")
+run_name="toy_model"
+output_dir="${base_output_dir}/${run_name}/${dt}"
+
+# ----------------------------------------------------
+# Set the device configuration for the Mac OS X platform
+# or for TPU v2-8/v3-8 frameworks (the Linux branch).
+# ----------------------------------------------------
+use_gpu=False
+if [ "$(uname)" = "Darwin" ]; then
+  tpu=False; num_cores=1; batch_size=5
+elif [ "$(uname)" = "Linux" ]; then
+  tpu='local'; num_cores=8; batch_size=8
+else
+  # Fail early instead of passing empty flag values to the trainer.
+  echo "Unsupported platform: $(uname)" >&2
+  exit 1
+fi
+
+# ----------------------------------------------------
+# Set the configuration file.
+# ----------------------------------------------------
+config_file="configs/${DATASET}/toy_model.py:runlocal"
+use_wandb=True
+eval_covariate_shift=False
+
+# ----------------------------------------------------
+# Call the model trainer (the last flag carries no trailing backslash).
+# ----------------------------------------------------
+python deterministic.py \
+  --output_dir="${output_dir}" \
+  --num_cores="${num_cores}" \
+  --use_gpu="${use_gpu}" \
+  --config="${config_file}" \
+  --config.batch_size="${batch_size}" \
+  --config.eval_covariate_shift="${eval_covariate_shift}" \
+  --config.use_wandb="${use_wandb}" \
+  --tpu="${tpu}"
diff --git a/experimental/robust_segvit/run_toy_mac.yaml b/experimental/robust_segvit/run_toy_mac.yaml
new file mode 100755
index 000000000..dba4db0bf
--- /dev/null
+++ b/experimental/robust_segvit/run_toy_mac.yaml
@@ -0,0 +1,34 @@
+name: toy_model
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 8
+
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/cityscapes/toy_model.py:runlocal"
+ - "--output_dir"
+ - "gs://ub-ekb/segmenter/cityscapes/toy_model"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/run_train_seed.sh b/experimental/robust_segvit/run_train_seed.sh
new file mode 100644
index 000000000..0625dcc43
--- /dev/null
+++ b/experimental/robust_segvit/run_train_seed.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+# Run each model config on each dataset for every seed, logging to wandb.
+
+# Device settings: Mac OS X for debugging, local TPU on Linux.
+use_gpu=False
+if [ "$(uname)" = "Darwin" ]; then
+  tpu=False; num_cores=1; batch_size=5
+elif [ "$(uname)" = "Linux" ]; then
+  tpu='local'; num_cores=8; batch_size=16
+else
+  # Fail early instead of passing empty flag values to the trainer.
+  echo "Unsupported platform: $(uname)" >&2
+  exit 1
+fi
+
+# Settings shared by all runs.
+use_wandb=True
+
+# Other available datasets: "ade20k_ind", "street_hazards".
+for dataset in "cityscapes"; do
+  for model in "deterministic" "gp" "het" "be"; do
+    for rng_seed in 1; do
+      base_output_dir="gs://ub-ekb/segmenter/${dataset}"
+      config_file="configs/${dataset}/${model}.py"
+      # NOTE(review): run_name ends in "_eval" although ${model}.py is the
+      # config used, and it omits rng_seed; confirm before adding more seeds.
+      run_name="${model}_eval"
+      output_dir="${base_output_dir}/${run_name}"
+      # The last flag has no trailing backslash, so `done` stays a keyword.
+      python deterministic.py \
+        --output_dir="${output_dir}" \
+        --num_cores="${num_cores}" \
+        --use_gpu="${use_gpu}" \
+        --config="${config_file}" \
+        --config.batch_size="${batch_size}" \
+        --config.use_wandb="${use_wandb}" \
+        --config.rng_seed="${rng_seed}" \
+        --tpu="${tpu}"
+    done
+  done
+done
diff --git a/experimental/robust_segvit/store_cityscapes_be_eval.yaml b/experimental/robust_segvit/store_cityscapes_be_eval.yaml
new file mode 100755
index 000000000..16d6baa9f
--- /dev/null
+++ b/experimental/robust_segvit/store_cityscapes_be_eval.yaml
@@ -0,0 +1,37 @@
+name: cityscapes_be_eval
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 16
+ config.eval_configs.store_logits:
+ value: true
+ config.eval_covariate_shift:
+ value: false
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/cityscapes/be_eval.py"
+ - "--output_dir"
+ - "cityscapes/be_eval"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/store_cityscapes_deterministic_eval.yaml b/experimental/robust_segvit/store_cityscapes_deterministic_eval.yaml
new file mode 100755
index 000000000..f57045b45
--- /dev/null
+++ b/experimental/robust_segvit/store_cityscapes_deterministic_eval.yaml
@@ -0,0 +1,37 @@
+name: cityscapes_deterministic_eval
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 16
+ config.eval_configs.store_logits:
+ value: true
+ config.eval_covariate_shift:
+ value: false
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/cityscapes/deterministic_eval.py"
+ - "--output_dir"
+ - "cityscapes/deterministic_eval"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/store_cityscapes_gp_eval.yaml b/experimental/robust_segvit/store_cityscapes_gp_eval.yaml
new file mode 100755
index 000000000..87b82c3cb
--- /dev/null
+++ b/experimental/robust_segvit/store_cityscapes_gp_eval.yaml
@@ -0,0 +1,37 @@
+name: cityscapes_gp_eval
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 16
+ config.eval_configs.store_logits:
+ value: true
+ config.eval_covariate_shift:
+ value: false
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/cityscapes/gp_eval.py"
+ - "--output_dir"
+ - "cityscapes/gp_eval"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/store_cityscapes_het_eval.yaml b/experimental/robust_segvit/store_cityscapes_het_eval.yaml
new file mode 100755
index 000000000..ba6eadaff
--- /dev/null
+++ b/experimental/robust_segvit/store_cityscapes_het_eval.yaml
@@ -0,0 +1,37 @@
+name: cityscapes_het_eval
+program: deterministic.py
+method: grid
+project: rdl-debug
+entity: ekellbuch
+
+metric:
+ name: valid_loss
+ goal: minimize
+parameters:
+ config.use_wandb:
+ value: true
+ config.wandb_project :
+ value: ${{project}}
+ config.wandb_entity :
+ value: ${{entity}}
+ config.batch_size:
+ value: 16
+ config.eval_configs.store_logits:
+ value: true
+ config.eval_covariate_shift:
+ value: false
+
+
+command:
+ - ${env}
+ - python
+ - ${program}
+ - "--config"
+ - "configs/cityscapes/het_eval.py"
+ - "--output_dir"
+ - "cityscapes/het_eval"
+ - "--num_cores"
+ - "8"
+ - "--tpu"
+ - "local"
+ - ${args}
\ No newline at end of file
diff --git a/experimental/robust_segvit/uncertainty_metrics.py b/experimental/robust_segvit/uncertainty_metrics.py
index 031c1fecb..e22934c9f 100644
--- a/experimental/robust_segvit/uncertainty_metrics.py
+++ b/experimental/robust_segvit/uncertainty_metrics.py
@@ -15,11 +15,10 @@
"""Calculate uncertainty metrics for segmentation tasks."""
from typing import Optional, Tuple
+import jax
from jax import lax
import jax.numpy as jnp
-from scenic.model_lib.base_models.model_utils import apply_weights
-
-# TODO(kellybuchanan): reconcile cases where mask is 0.
+from scenic.model_lib.layers import nn_ops
def calculate_num_patches_binary_maps(
@@ -29,7 +28,7 @@ def calculate_num_patches_binary_maps(
Args:
binary_acc_map : binary accuracy map
- binary_unc_map : binary uncertainty map
+ binary_unc_map : binary uncertainty map (1=certain, 0=uncertain)
Returns:
metrics to calculate uncertainty scores
@@ -37,24 +36,24 @@ def calculate_num_patches_binary_maps(
# number of patches that are accurate and certain
n_ac = jnp.sum(
jnp.logical_and(
- jnp.equal(binary_acc_map, 1), jnp.equal(binary_unc_map, 0)),
+ jnp.equal(binary_acc_map, 1), jnp.equal(binary_unc_map, 1)),
axis=(-1, -2))
# number of patches that are inaccurate and certain
n_ic = jnp.sum(
jnp.logical_and(
- jnp.equal(binary_acc_map, 0), jnp.equal(binary_unc_map, 0)),
+ jnp.equal(binary_acc_map, 0), jnp.equal(binary_unc_map, 1)),
axis=(-1, -2))
# number of patches that are inaccurate and uncertain
n_iu = jnp.sum(
jnp.logical_and(
- jnp.equal(binary_acc_map, 0), jnp.equal(binary_unc_map, 1)),
+ jnp.equal(binary_acc_map, 0), jnp.equal(binary_unc_map, 0)),
axis=(-1, -2))
# number of patches that are accurate and uncertain
n_au = jnp.sum(
jnp.logical_and(
- jnp.equal(binary_acc_map, 1), jnp.equal(binary_unc_map, 1)),
+ jnp.equal(binary_acc_map, 1), jnp.equal(binary_unc_map, 0)),
axis=(-1, -2))
unc_confusion_matrix = jnp.stack((n_ac, n_ic, n_iu, n_au), axis=-1)
@@ -86,17 +85,18 @@ def get_pavpu(unc_confusion_matrix):
def get_uncertainty_confusion_matrix(
+ *,
logits: jnp.ndarray,
labels: jnp.ndarray,
+ uncertainty_measure: str = 'softmax',
+ accuracy_measure : str = 'predictive_accuracy',
weights: Optional[jnp.ndarray] = None,
accuracy_th: Optional[float] = 0.5,
uncertainty_th: Optional[float] = 0.5,
- window_size: Optional[int] = 2
+ window_size: Optional[int] = 2,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Calculate counts of patches accurate/inacurate and certain/uncertain.
- TODO(kellybuchanan): include weights for entropy calculation.
-
Args:
logits: predicted logits
labels: true labels
@@ -118,29 +118,43 @@ def get_uncertainty_confusion_matrix(
preds = jnp.argmax(logits, axis=-1)
# calculate binary accuracy map
- correct = jnp.equal(preds, targets)
+ correct = jnp.equal(preds, targets).astype(jnp.float32)
- # batch masking
- if weights is not None:
- correct = apply_weights(correct, weights)
+ if weights is None:
+ weights = jnp.ones(correct.shape)
- correct = correct.astype(jnp.float32)
+ weights = weights.astype(jnp.float32)
- # A given patch is accurate if its acc > accuracy_threshold
- binary_acc_map = reduce_2dmap(correct, window_size,
- accuracy_th).astype(jnp.float32)
+ if accuracy_measure == 'predictive_accuracy':
+ accuracy_map = correct
+ else:
+ raise NotImplementedError('Accuracy measure not implemented.')
- # Calculate uncertainty map
- entropy = get_entropy_from_logits(logits)
+ # A given patch is accurate if its acc > accuracy_threshold
+ binary_acc_map = reduce_2dmap_weighted(accuracy_map,
+ weights,
+ window_size=window_size,
+ threshold=accuracy_th).astype(jnp.float32)
+
+ # Calculate uncertainty map:
+ if uncertainty_measure == 'softmax':
+ uncertainty_map = jnp.max(jax.nn.softmax(logits, -1), -1)
+ elif uncertainty_measure == 'entropy':
+ uncertainty_map = get_entropy_from_logits(logits)
+ else:
+ raise NotImplementedError(f'Uncertainty measure {uncertainty_measure} not implemented.')
- # A given patch is uncertain if its uncertainty > uncertainty_th
- binary_unc_map = reduce_2dmap(entropy, window_size,
- uncertainty_th).astype(jnp.float32)
+  # Patch is certain (1) if its pooled score >= uncertainty_th; NOTE(review): confirm polarity for 'entropy'.
+ binary_unc_map = reduce_2dmap_weighted(uncertainty_map,
+ weights,
+ window_size=window_size,
+ threshold=uncertainty_th).astype(jnp.float32)
# number of patches that are accurate and certain
unc_confusion_matrix = calculate_num_patches_binary_maps(
binary_acc_map, binary_unc_map)
+ unc_confusion_matrix = unc_confusion_matrix[jnp.newaxis, ...] # Dummy batch dim.
return unc_confusion_matrix
@@ -159,14 +173,14 @@ def reduce_2dmap(
"""Given a map, apply a 2d spatial strided convolution to avg adjacent values.
Args:
- array_map: array to be split.
+    array_map: array to be split. 3-D Tensor with shape `[batch, in_rows, in_cols]`.
window_size: size of window.
threshold: threshold for binarization.
Returns:
binary_map: binary map.
"""
- # Expand dimension to match filter C dimension.
+ # Expand dimension for dummy depth dimension
array_map = jnp.expand_dims(array_map, -1)
# Create a kernel
@@ -193,6 +207,45 @@ def reduce_2dmap(
return binary_map.astype(jnp.int32)
+def reduce_2dmap_weighted(
+    array_map: jnp.ndarray,
+    weights: jnp.ndarray,
+    window_size: int = 4,
+    threshold: float = 0.5,
+) -> jnp.ndarray:
+  """Pool a 2-D map with a weighted average, then binarize the result.
+
+  The map is pooled over square windows of side `window_size` by passing it
+  together with `weights` to `nn_ops.weighted_avg_pool`. The stride equals
+  the window shape and the padding is 'VALID'.
+
+  Args:
+    array_map: map to pool. 3-D array with shape `[batch, in_rows, in_cols]`.
+    weights: array of weights. 3-D array with shape `[batch, in_rows, in_cols]`.
+    window_size: side length of the square pooling window.
+    threshold: pooled values greater than or equal to this are set to 1.
+
+  Returns:
+    binary_map: int32 array with the dummy feature axis removed; 1 where the
+      pooled value reaches `threshold`, 0 elsewhere.
+  """
+  pool_window = (window_size, window_size)
+
+  # Add a dummy trailing feature axis before pooling.
+  pooled = nn_ops.weighted_avg_pool(
+      jnp.expand_dims(array_map, -1),
+      weights,
+      window_shape=pool_window,
+      strides=pool_window,
+      padding='VALID')
+
+  # Compare against the threshold, then drop the dummy feature axis again.
+  reached = jnp.greater_equal(pooled, threshold)
+  reached = jnp.squeeze(reached, -1)
+
+  return reached.astype(jnp.int32)
+
+
class SegmentationUncertaintyMetrics(object):
"""Calculate uncertainty scores for image segmentation task."""
diff --git a/experimental/robust_segvit/uncertainty_metrics_test.py b/experimental/robust_segvit/uncertainty_metrics_test.py
deleted file mode 100644
index c3a4c2a78..000000000
--- a/experimental/robust_segvit/uncertainty_metrics_test.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Uncertainty Baselines Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for uncertainty_metrics."""
-
-from absl.testing import absltest
-from absl.testing import parameterized
-import jax.numpy as jnp
-from uncertainty_metrics import reduce_2dmap # local file import from experimental.robust_segvit
-from uncertainty_metrics import SegmentationUncertaintyMetrics # local file import from experimental.robust_segvit
-
-
-class UncertaintyMetricsTest(parameterized.TestCase):
-
- def setUp(self):
- super(UncertaintyMetricsTest, self).setUp()
- self.targets = jnp.asarray([[[1, 2, 5, 7], [6, 4, 3, 3], [10, 9, 5, 0],
- [8, 6, 4, 4]]])
-
- self.preds = jnp.asarray([[[1, 2, 4, 7], [5, 6, 3, 3], [10, 9, 4, 0],
- [8, 7, 3, 4]]])
-
- self.unc_map = jnp.asarray([[[0.1, 0.3, 0.6, 0.3], [0.7, 0.6, 0.2, 0.1],
- [0.2, 0.4, 0.5, 0.3], [0.1, 0.7, 0.6, 0.2]]])
-
- # create logit map from unc_map by mapping entropy vals (0.1,0.7)
- # to a feasible range of logit vals:(4.1, 6.2)
- self.logit_map = 4.1 + (0.7 - self.unc_map) * (6.2 - 4.1) / (0.7 - 0.1)
-
- self.window_size = 2
- self.accuracy_th = 0.5
- self.uncertainty_th = 0.4
-
- # true values
- self.true_binary_acc_map = jnp.asarray([[[0, 1], [1, 0]]])
- self.true_binary_unc_map = jnp.asarray([[[1, 0], [0, 0]]])
- self.true_p_accurate_certain = jnp.asarray([0.67])
- self.true_p_uncertain_innacurate = jnp.asarray([0.5])
- self.true_pavpu = jnp.asarray([0.75])
-
- # construct logits passed as input from unc_map
- self.num_classes = 11
- self.img_size = 4
-
- true_mask = jnp.arange(self.img_size * self.img_size
- ) * self.num_classes + self.preds.flatten()
- logits = jnp.zeros((self.img_size * self.img_size * self.num_classes))
- logits = logits.at[true_mask].set(self.logit_map.flatten())
- self.logits = jnp.expand_dims(
- logits.reshape((self.img_size, self.img_size, self.num_classes)), 0)
-
- def test_setup(self):
- preds_logits = jnp.argmax(self.logits, -1)
- self.assertTrue(jnp.array_equal(self.preds, preds_logits))
-
- def test_calculate_pacc_cert(self):
- segment_unc = SegmentationUncertaintyMetrics(
- logits=self.logits,
- labels=self.targets,
- window_size=self.window_size,
- accuracy_th=self.accuracy_th,
- uncertainty_th=self.uncertainty_th)
-
- self.assertEqual(self.true_pavpu, segment_unc.pavpu)
- self.assertAlmostEqual(self.true_p_accurate_certain, segment_unc.pacc_cert,
- 2)
- self.assertAlmostEqual(self.true_p_uncertain_innacurate,
- segment_unc.puncert_inacc, 2)
-
- @parameterized.parameters((1), (2), (3))
- def test_reduce_2dmap(self, batch_size):
- array_map = jnp.repeat(jnp.ones((1, 4, 4)), batch_size, axis=0)
- true_binary_map = jnp.repeat(jnp.ones((1, 2, 2)), batch_size, axis=0)
- binary_map = reduce_2dmap(array_map, self.window_size, self.accuracy_th)
-
- self.assertTrue(jnp.array_equal(true_binary_map, binary_map))
-
-
-if __name__ == '__main__':
- absltest.main()