diff --git a/experimental/shoshin/configs/base_config.py b/experimental/shoshin/configs/base_config.py index 3ca453c55..885e5c7e5 100644 --- a/experimental/shoshin/configs/base_config.py +++ b/experimental/shoshin/configs/base_config.py @@ -49,13 +49,13 @@ def get_data_config(): config.num_classes = 2 config.batch_size = 64 # Number of slices into which train and val will be split. - config.num_splits = 5 + config.num_splits = 100 # Ratio of splits that will be considered out-of-distribution from each # combination, e.g. when num_splits == 5 and ood_ratio == 0.4, 2 out 5 # slices will be excluded for every combination of training data. config.ood_ratio = 0.4 - # Indices of data splits to include in training. All by default. - config.included_splits_idx = (0, 1, 2, 3, 4) + # Indices of data splits to include in training. + config.split_id = 0 # Subgroup IDs. Specify them in an experiment config. For example, for # Waterbirds, the subgroup IDs might be ('0_1', '1_0') for landbirds on water # and waterbirds on land, respectively. @@ -64,6 +64,13 @@ def get_data_config(): # Waterbirds, the subgroup proportions might be (0.05, 0.05), meaning each # subgroup will represent 5% of the dataset. config.subgroup_proportions = () + config.split_seed = 0 + config.initial_sample_seed = 0 + config.split_proportion = 1.0 + + # Leave one out training + config.loo_id = '' + config.loo_training = False # Proportion of training set to sample initially. Rest is considered the pool # for active sampling. @@ -151,6 +158,9 @@ def get_config() -> ml_collections.ConfigDict: # Round of acitve sampling being performed config.round_idx = -1 + # Keep predictions of individual models + config.keep_individual_predictions = True + # Whether to generate bias table (from stage one models) or prediction table # (from stage two models) config.generate_bias_table = True diff --git a/experimental/shoshin/data.py b/experimental/shoshin/data.py index 59b4d68d2..cb088b68b 100644 --- a/experimental/shoshin/data.py +++ b/experimental/shoshin/data.py @@ -25,8 +25,9 @@ import dataclasses import json import os -from typing import Any, Dict, Iterator, Optional, Tuple, List, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +import pandas as pd import tensorflow as tf import tensorflow_datasets as tfds @@ -71,8 +72,8 @@ def get_dataset(name: str): @dataclasses.dataclass class Dataloader: - train_splits: tf.data.Dataset # Result of tfds.load with 'split' arg. - val_splits: tf.data.Dataset # Result of tfds.load with 'split' arg. + train_splits: List[tf.data.Dataset] # Result of tfds.load with 'split' arg. + val_splits: List[tf.data.Dataset] # Result of tfds.load with 'split' arg. train_ds: tf.data.Dataset # Dataset with all the train splits combined. train_sample_ds: Optional[tf.data.Dataset] = None # Subsample of train set. eval_ds: Optional[Dict[ @@ -120,6 +121,15 @@ def get_train_ids(dataloader: Dataloader): return ids_train +def filter_ids_fn(hash_table, value=1): + # Filter dataset based on whether ids take a certain value in hash table. + + def filter_fn(feats, label, example_ids): + del feats, label + return hash_table.lookup(example_ids) == value + return filter_fn + + class CardiotoxFingerprintDataset(tfds.core.GeneratorBasedBuilder): """DatasetBuilder for cardiotoxicity fingerprint dataset.""" @@ -455,7 +465,12 @@ def get_waterbirds_dataset( initial_sample_proportion: float, subgroup_ids: List[str], subgroup_proportions: List[float], + split_proportion: Optional[float] = 0.7, + initial_sample_seed: Optional[int] = 0, + split_seed: Optional[int] = 0, is_training: Optional[bool] = True, + loo_training: Optional[bool] = False, + loo_id: Optional[str] = None, ) -> Dataloader: """Returns datasets for training, validation, and possibly test sets. @@ -466,41 +481,135 @@ def get_waterbirds_dataset( subgroup_ids: List of strings of IDs indicating subgroups. subgroup_proportions: List of floats indicating proportion that each subgroup should take in initial training dataset. + split_proportion: Relative size of each split + initial_sample_seed: Seed used for supsampling from org dataset + split_seed: Seed used to sample the splits is_training: Dataset used for evaluation (in this case as_supervised is set to True in the val/test) + loo_training: Train splits with only a single sample removed + loo_id: the id of the removed sample Returns: A tuple containing the split training data, split validation data, the combined training dataset, and a dictionary mapping evaluation dataset names to their respective combined datasets. """ - split_size_in_pct = int(100 * initial_sample_proportion / num_splits) - reduced_datset_sz = int(100 * initial_sample_proportion) + + meta_data_df = pd.read_csv(_WATERBIRDS_METADATA_DIR) + reduced_datset_sz = int( + (meta_data_df['split'] == 0).sum() * initial_sample_proportion) + builder_kwargs = { 'subgroup_ids': subgroup_ids, 'subgroup_proportions': subgroup_proportions } - val_splits = tfds.load( - 'waterbirds_dataset', - split=[ - f'validation[{k}%:{k+split_size_in_pct}%]' - for k in range(0, reduced_datset_sz, split_size_in_pct) - ], - data_dir=DATA_DIR, - builder_kwargs=builder_kwargs, - try_gcs=False, - as_supervised=is_training) - train_splits = tfds.load( + train_ds = tfds.load( 'waterbirds_dataset', - split=[ - f'train[{k}%:{k+split_size_in_pct}%]' - for k in range(0, reduced_datset_sz, split_size_in_pct) - ], + split='train', data_dir=DATA_DIR, builder_kwargs=builder_kwargs, try_gcs=False, - as_supervised=is_training) + as_supervised=True, + with_info=False) + + def create_ids_table(meta_data_df, + initial_sample_size, + initial_sample_seed, + split_proportion, + num_splits, + split_seed, + train=True) -> List[tf.lookup.StaticHashTable]: + """Creates static hash tables representing ids in each file in ids_dir for each split.""" + ids_tables = [] + meta_data_df = meta_data_df[meta_data_df['split'] == 0] # Get trainset + ids = meta_data_df['img_filename'] + subset_ids = ids.sample( + n=initial_sample_size, random_state=initial_sample_seed) + # ids_dir is populated by the sample_and_split_ids function above + for split_num in range(num_splits): + ids_i = subset_ids.sample( + frac=split_proportion, random_state=split_num + split_seed) + if not train: + ids_i = subset_ids[~subset_ids.isin(ids_i)] + if split_proportion >= 1: + # If the split is the entire subsample we use the + # remaining data as validation, only needed for loo training. + ids_i = ids[~ids.isin(subset_ids)] + keys = tf.convert_to_tensor(ids_i, dtype=tf.string) + values = tf.ones(shape=keys.shape, dtype=tf.int64) + init = tf.lookup.KeyValueTensorInitializer( + keys=keys, + values=values, + key_dtype=tf.string, + value_dtype=tf.int64) + ids_tables.append(tf.lookup.StaticHashTable(init, default_value=0)) + return ids_tables + + def create_loo_ids_table(meta_data_df, + initial_sample_size, + initial_sample_seed, + loo_id, + num_splits, + train=True) -> List[tf.lookup.StaticHashTable]: + """Creates static hash table representing ids in each file in ids_dir with all ids except loo_id.""" + ids_tables = [] + meta_data_df = meta_data_df[meta_data_df['split'] == 0] # Get trainset + ids = meta_data_df['img_filename'] + subset_ids = ids.sample( + n=initial_sample_size, random_state=initial_sample_seed) + # ids_dir is populated by the sample_and_split_ids function above + ids_i = subset_ids[subset_ids != loo_id] + if not train: + ids_i = subset_ids[~subset_ids.isin(ids_i)] + keys = tf.convert_to_tensor(ids_i, dtype=tf.string) + values = tf.ones(shape=keys.shape, dtype=tf.int64) + init = tf.lookup.KeyValueTensorInitializer( + keys=keys, values=values, key_dtype=tf.string, value_dtype=tf.int64) + ids_tables = [ + tf.lookup.StaticHashTable(init, default_value=0) + for split_num in range(num_splits) + ] + return ids_tables + + if loo_training: + train_ids = create_loo_ids_table( + meta_data_df, + initial_sample_size=reduced_datset_sz, + initial_sample_seed=0, + loo_id=loo_id, + num_splits=num_splits, + train=True) + val_ids = create_loo_ids_table( + meta_data_df, + initial_sample_size=reduced_datset_sz, + initial_sample_seed=0, + loo_id=loo_id, + num_splits=num_splits, + train=False) + else: + train_ids = create_ids_table( + meta_data_df, + initial_sample_size=reduced_datset_sz, + initial_sample_seed=initial_sample_seed, + split_proportion=split_proportion, + num_splits=num_splits, + split_seed=split_seed, + train=True) + val_ids = create_ids_table( + meta_data_df, + initial_sample_size=reduced_datset_sz, + initial_sample_seed=initial_sample_seed, + split_proportion=split_proportion, + num_splits=num_splits, + split_seed=split_seed, + train=False) + + train_splits = [ + train_ds.filter(filter_ids_fn(ids_tab)) for ids_tab in train_ids + ] + + val_splits = [train_ds.filter(filter_ids_fn(ids_tab)) for ids_tab in val_ids] train_sample = tfds.load( 'waterbirds_dataset', @@ -520,8 +629,15 @@ def get_waterbirds_dataset( as_supervised=is_training, with_info=False) - train_ds = gather_data_splits(list(range(num_splits)), train_splits) - val_ds = gather_data_splits(list(range(num_splits)), val_splits) + val_ds = tfds.load( + 'waterbirds_dataset', + split='validation', + data_dir=DATA_DIR, + builder_kwargs=builder_kwargs, + try_gcs=False, + as_supervised=is_training, + with_info=False) + eval_datasets = { 'val': val_ds, 'test': test_ds, diff --git a/experimental/shoshin/train_tf.py b/experimental/shoshin/train_tf.py index 25f63b212..687553153 100644 --- a/experimental/shoshin/train_tf.py +++ b/experimental/shoshin/train_tf.py @@ -97,10 +97,17 @@ def main(_) -> None: logging.info('Running Round %d of Training.', config.round_idx) if config.round_idx == 0: # If initial round of sampling, sample randomly initial_sample_proportion - dataloader = dataset_builder(config.data.num_splits, - config.data.initial_sample_proportion, - config.data.subgroup_ids, - config.data.subgroup_proportions) + dataloader = dataset_builder( + config.data.num_splits, + config.data.initial_sample_proportion, + config.data.subgroup_ids, + config.data.subgroup_proportions, + initial_sample_seed=config.data.initial_sample_seed, + split_seed=config.data.split_seed, + split_proportion=config.data.split_proportion, + loo_training=config.data.loo_training, + loo_id=config.data.loo_id, + ) else: # If latter round, keep track of split generated in last round of active # sampling @@ -127,17 +134,14 @@ def main(_) -> None: logging.info( 'Error: Bias table not found') return - # Training a single model on a combination of data splits. - included_splits_idx = [int(i) for i in config.data.included_splits_idx] - train_ds = data.gather_data_splits(included_splits_idx, - dataloader.train_splits) - val_ds = data.gather_data_splits(included_splits_idx, - dataloader.val_splits) + # Training a single model on a specific + train_ds = dataloader.train_splits[int(config.data.split_id)] + val_ds = dataloader.val_splits[int(config.data.split_id)] dataloader.train_ds = train_ds dataloader.eval_ds['val'] = val_ds experiment_name = 'stage_2' if config.train_bias else 'stage_1' - _ = train_tf_lib.train_and_evaluate( + model = train_tf_lib.train_and_evaluate( train_as_ensemble=config.train_stage_2_as_ensemble, dataloader=dataloader, model_params=model_params, @@ -150,6 +154,9 @@ def main(_) -> None: ensemble_dir=FLAGS.ensemble_dir, example_id_to_bias_table=example_id_to_bias_table) + if config.keep_individual_predictions: + train_tf_lib.create_predictions_table(model, dataloader, output_dir) + if __name__ == '__main__': app.run(main) diff --git a/experimental/shoshin/train_tf_lib.py b/experimental/shoshin/train_tf_lib.py index 3c0ce5223..c9a457b84 100644 --- a/experimental/shoshin/train_tf_lib.py +++ b/experimental/shoshin/train_tf_lib.py @@ -27,6 +27,7 @@ from absl import logging import numpy as np +import pandas as pd import tensorflow as tf import data # local file import from experimental.shoshin import models # local file import from experimental.shoshin @@ -535,3 +536,40 @@ def train_and_evaluate( example_id_to_bias_table=example_id_to_bias_table) evaluate_model(two_head_model, checkpoint_dir, dataloader.eval_ds) return two_head_model + + +def create_predictions_table(model, dataloader, save_dir, save_table=True): + """Generates a lookup table mapping example ID and training set membership to prediction. + + Args: + model: Trained model + dataloader: Dataclass object containing training and validation data. + save_dir: Directory in which predictions table will be saved as CSV. + save_table: Boolean for whether or not to save table. + Returns: + A pandas dataframe mapping example ID topredictions and group membership. + """ + + table_name = 'predictions_table' + dfs = [] + for is_train, ds in enumerate([dataloader.eval_ds['val'], + dataloader.train_ds]): + labels = list( + ds.map( + lambda feats, label, example_id: label).as_numpy_iterator()) + labels = np.concatenate(labels) + predictions = model.predict(ds) + example_ids = list(ds.map( + lambda feats, label, example_id: example_id).as_numpy_iterator()) + example_ids = np.concatenate(example_ids) + in_train = is_train*np.ones_like(example_ids) + dict_values = {'example_id': example_ids} + dict_values['in_train'] = in_train + dict_values['predictions_label'] = predictions['main'][..., 1] + dfs.append(pd.DataFrame(dict_values)) + df = pd.concat(dfs) + if save_table: + df.to_csv(os.path.join(save_dir, table_name + '.csv'), index=False) + return df + +