From a740487ffe7fbd8c84050d620ff6d625042897cd Mon Sep 17 00:00:00 2001 From: Uncertainty Baselines Team Date: Mon, 5 Dec 2022 23:46:11 -0800 Subject: [PATCH] Allows calculation of Tracin values PiperOrigin-RevId: 493218263 --- .../waterbirds_resnet_tracin_config.py | 64 ++++++++++ experimental/shoshin/generate_tracin_table.py | 111 ++++++++++++++++++ 2 files changed, 175 insertions(+) create mode 100644 experimental/shoshin/configs/waterbirds_resnet_tracin_config.py create mode 100644 experimental/shoshin/generate_tracin_table.py diff --git a/experimental/shoshin/configs/waterbirds_resnet_tracin_config.py b/experimental/shoshin/configs/waterbirds_resnet_tracin_config.py new file mode 100644 index 000000000..0c7f3d122 --- /dev/null +++ b/experimental/shoshin/configs/waterbirds_resnet_tracin_config.py @@ -0,0 +1,64 @@ +# coding=utf-8 +# Copyright 2022 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

"""Configuration file for experiment with Waterbirds data and ResNet model."""

import ml_collections
from configs import base_config  # local file import from experimental.shoshin


def get_signal_config():
  """Builds the `signal` sub-config (checkpoint selection and layer choice).

  Returns:
    A `ml_collections.ConfigDict` holding `checkpoint_selection`,
    `checkpoint_list`, `checkpoint_name`, `checkpoint_number` and
    `included_layers`.
  """
  signal = ml_collections.ConfigDict()

  # Which checkpoints to use.
  signal.checkpoint_selection = 'first'
  signal.checkpoint_list = ['epoch-01-val_auc-0.76.ckpt']
  signal.checkpoint_name = ''
  signal.checkpoint_number = 5

  # Layer selection; NOTE(review): presumably a negative index into the model's
  # layers -- confirm against the library code that consumes it.
  signal.included_layers = -2
  return signal


def get_config() -> ml_collections.ConfigDict:
  """Builds the Waterbirds / ResNet config on top of the shared base config.

  Returns:
    The base config with data, model, active-sampling and signal fields
    overridden for this experiment.
  """
  cfg = base_config.get_config()

  # Data. Landbirds on water and waterbirds on land would be the subgroups,
  # but no subgroups are configured here.
  cfg.data.subgroup_ids = ()  # ('0_1', '1_0')
  cfg.data.subgroup_proportions = ()  # (0.04, 0.012)
  cfg.data.initial_sample_proportion = 1
  cfg.data.name = 'waterbirds10k'
  cfg.data.num_classes = 2

  # Active sampling schedule.
  cfg.active_sampling.num_samples_per_round = 500
  cfg.num_rounds = 4

  # Model.
  cfg.model.name = 'resnet'
  cfg.model.dropout_rate = 0.2

  cfg.output_dir = ''
  cfg.generate_individual_table = True
  cfg.round_idx = 0

  # To ensure that the last layers actually predict the outcome and not the
  # bias.
  cfg.train_bias = False

  cfg.signal = get_signal_config()
  return cfg


# Patch metadata and license header start of the next file in this diff, kept
# verbatim from the original (as comments):
# diff --git a/experimental/shoshin/generate_tracin_table.py b/experimental/shoshin/generate_tracin_table.py
# new file mode 100644
# index 000000000..50d619721
# --- /dev/null
# +++ b/experimental/shoshin/generate_tracin_table.py
# @@ -0,0 +1,111 @@
# coding=utf-8
# Copyright 2022 The Uncertainty Baselines Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Binary executable for generating tracin table.

This file serves as a binary to calculate tracin values and create a lookup table
that maps from example ID to tracin label.

Usage:
# pylint: disable=line-too-long

  ml_python3 third_party/py/uncertainty_baselines/experimental/shoshin/generate_tracin_table.py \
      --adhoc_import_modules=uncertainty_baselines \
      -- \
      --xm_runlocal \
      --logtostderr \
      --config=third_party/py/uncertainty_baselines/experimental/shoshin/configs/waterbirds_resnet_tracin_config.py

# pylint: enable=line-too-long

Note: In output_dir, models trained on different splits of data must already
exist and be present in directory.
"""

import os

from absl import app
from absl import flags
from absl import logging
from ml_collections import config_flags
import data  # local file import from experimental.shoshin
import generate_bias_table_lib  # local file import from experimental.shoshin
import models  # local file import from experimental.shoshin
import sampling_policies  # local file import from experimental.shoshin
from configs import base_config  # local file import from experimental.shoshin


FLAGS = flags.FLAGS
config_flags.DEFINE_config_file('config')


def main(_) -> None:
  """Loads data and checkpoints per `--config` and builds the tracin table.

  Reads the experiment config from the `config` flag, constructs the model
  parameters and the dataloader, loads the selected checkpoints from the
  checkpoint subdirectory of `config.output_dir`, and asks
  `generate_bias_table_lib` to compute the example-id -> tracin-value table
  with `save_table=True` and `save_dir=config.save_dir`.

  Raises:
    NotImplementedError: If `config.generate_individual_table` is false;
      combining individual tables is not implemented yet.
  """
  config = FLAGS.config
  base_config.check_flags(config)

  ckpt_dir = os.path.join(
      config.output_dir, generate_bias_table_lib.CHECKPOINT_SUBDIR)
  model_params = models.ModelTrainingParameters(
      model_name=config.model.name,
      train_bias=config.train_bias,
      num_classes=config.data.num_classes,
      num_subgroups=0,  # Overwritten below once the dataloader is built.
      num_epochs=config.training.num_epochs,
      learning_rate=config.optimizer.learning_rate,
      hidden_sizes=config.model.hidden_sizes,
  )

  dataset_builder = data.get_dataset(config.data.name)
  if not config.generate_individual_table:
    # TODO(martinstrobel): Combine individual tracinvalues to a mean value
    raise NotImplementedError('Not implemented yet')

  # Round 0 uses the configured initial sample proportion; later rounds pass 1.
  is_first_round = config.round_idx == 0
  sample_proportion = (
      config.data.initial_sample_proportion if is_first_round else 1)
  dataloader = dataset_builder(
      config.data.num_splits,
      sample_proportion,
      config.data.subgroup_ids,
      config.data.subgroup_proportions,
  )
  if not is_first_round:
    # Filter each split to only have examples from example_ids_table
    # NOTE(review): the source's indentation was lost; this filtering is
    # assumed to belong to the later-round branch only -- confirm.
    id_tables = sampling_policies.convert_ids_to_table(config.ids_dir)
    dataloader.train_splits = [
        dataloader.train_ds.filter(
            generate_bias_table_lib.filter_ids_fn(id_table))
        for id_table in id_tables
    ]

  dataloader = data.apply_batch(dataloader, config.data.batch_size)
  model_params.num_subgroups = dataloader.num_subgroups

  signal = config.signal
  model_checkpoints = generate_bias_table_lib.load_model_checkpoints(
      ckpt_dir,
      model_params,
      signal.checkpoint_list,
      signal.checkpoint_selection,
      signal.checkpoint_number,
      signal.checkpoint_name,
  )
  logging.info('%s checkpoints loaded', len(model_checkpoints))

  # The table is named after the checkpoint when selecting by name, otherwise
  # after the selection strategy itself.
  table_name = (
      signal.checkpoint_name
      if signal.checkpoint_selection == 'name'
      else signal.checkpoint_selection)
  generate_bias_table_lib.get_example_id_to_tracin_value_table(
      dataloader=dataloader,
      model_checkpoints=model_checkpoints,
      included_layers=signal.included_layers,
      save_dir=config.save_dir,
      save_table=True,
      table_name=table_name,
  )


if __name__ == '__main__':
  app.run(main)